Posted to commits@labs.apache.org by to...@apache.org on 2015/11/12 14:39:19 UTC

svn commit: r1714045 - in /labs/yay/trunk: ./ core/ core/src/main/java/org/apache/yay/core/ core/src/main/java/org/apache/yay/core/utils/ core/src/test/java/org/apache/yay/core/

Author: tommaso
Date: Thu Nov 12 13:39:18 2015
New Revision: 1714045

URL: http://svn.apache.org/viewvc?rev=1714045&view=rev
Log:
performance improvements on training set creation and softmax

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java   (with props)
Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Thu Nov 12 13:39:18 2015
@@ -52,5 +52,23 @@
             <artifactId>commons-collections</artifactId>
             <version>3.2.1</version>
         </dependency>
+        <dependency>
+            <groupId>com.google.guava</groupId>
+            <artifactId>guava</artifactId>
+            <version>18.0</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-surefire-plugin</artifactId>
+                <version>2.18.1</version>
+                <configuration>
+                    <argLine>-Xmx8g</argLine>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
 </project>
\ No newline at end of file
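
The 8 GB test heap is presumably sized for the dense one-hot encodings used by the word vector test below. As a purely illustrative figure: with a 10,000-word vocabulary and a window of 4, one fully encoded training example holds 10,000 boxed Double inputs plus 3 x 10,000 boxed Double outputs, i.e. 40,000 objects or roughly 640 KB at ~16 bytes per boxed Double, before any matrix arithmetic starts.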

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Thu Nov 12 13:39:18 2015
@@ -163,7 +163,7 @@ public class BackPropagationLearningStra
         }
       };
       realMatrix.walkInOptimizedOrder(visitor);
-      if (updatedParameters[l]== null) {
+      if (updatedParameters[l] == null) {
         updatedParameters[l] = realMatrix;
       }
     }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java?rev=1714045&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java Thu Nov 12 13:39:18 2015
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.yay.Feature;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+import org.apache.yay.core.utils.ExamplesFactory;
+
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * A one-hot encoded {@link TrainingSet}; only index values are stored.
+ */
+public class EncodedTrainingSet extends TrainingSet<Double, Double> {
+  private final List<String> vocabulary;
+  private final int window;
+
+  public EncodedTrainingSet(Collection<TrainingExample<Double, Double>> samples, List<String> vocabulary, int window) {
+    super(samples);
+    this.vocabulary = vocabulary;
+    this.window = window;
+  }
+
+  @Override
+  public int size() {
+    return super.size();
+  }
+
+  @Override
+  public Iterator<TrainingExample<Double, Double>> iterator() {
+    // hold one delegate iterator: calling super.iterator() on every
+    // hasNext()/next() invocation would restart the iteration each time
+    final Iterator<TrainingExample<Double, Double>> delegate = super.iterator();
+    return new Iterator<TrainingExample<Double, Double>>() {
+      @Override
+      public boolean hasNext() {
+        return delegate.hasNext();
+      }
+
+      @Override
+      public TrainingExample<Double, Double> next() {
+        TrainingExample<Double, Double> sample = delegate.next();
+        Collection<Feature<Double>> features = sample.getFeatures();
+        int vocabularySize = vocabulary.size();
+        Double[] outputs = new Double[vocabularySize * (window - 1)];
+        Double[] inputs = new Double[vocabularySize];
+        for (Feature<Double> feature : features) {
+          inputs = ConversionUtils.hotEncode(feature.getValue().intValue(), vocabularySize);
+          break;
+        }
+        int k = 0;
+        for (Double d : sample.getOutput()) {
+          Double[] currentOutput = ConversionUtils.hotEncode(d.intValue(), vocabularySize);
+          System.arraycopy(currentOutput, 0, outputs, k, currentOutput.length);
+          k += vocabularySize;
+        }
+        return ExamplesFactory.createDoubleArrayTrainingExample(outputs, inputs);
+      }
+    };
+  }
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java
------------------------------------------------------------------------------
    svn:eol-style = native
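
A minimal usage sketch of the new class, built only on the yay APIs visible in this commit (Feature, TrainingExample, TrainingSet); the four-word vocabulary and the index values are made up for illustration:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;

    import org.apache.yay.Feature;
    import org.apache.yay.TrainingExample;
    import org.apache.yay.TrainingSet;
    import org.apache.yay.core.EncodedTrainingSet;

    List<String> vocabulary = Arrays.asList("brown", "fox", "quick", "the"); // must be sorted
    // one sample: the input word's vocabulary index as its only feature,
    // the (window - 1) context words' indices as its output
    TrainingExample<Double, Double> sample = new TrainingExample<Double, Double>() {
      @Override
      public ArrayList<Feature<Double>> getFeatures() {
        ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
        Feature<Double> feature = new Feature<Double>();
        feature.setValue(3d); // index of "the"
        features.add(feature);
        return features;
      }

      @Override
      public Double[] getOutput() {
        return new Double[]{2d, 0d, 1d}; // indices of "quick", "brown", "fox"
      }
    };
    TrainingSet<Double, Double> trainingSet = new EncodedTrainingSet(
        Collections.singletonList(sample), vocabulary, 4);
    TrainingExample<Double, Double> encoded = trainingSet.iterator().next();
    // encoded should now carry a one-hot input vector of length 4 (1.0 at index 3)
    // and three concatenated one-hot output vectors (total length 12)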

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Thu Nov 12 13:39:18 2015
@@ -79,27 +79,43 @@ public class FeedForwardStrategy impleme
       // apply the activation function to each element in the matrix
       int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
       final ActivationFunction<Double> af = activationFunctionMap.get(idx);
-      RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {
 
-        @Override
-        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
-        }
-
-        @Override
-        public double visit(int row, int column, double value) {
-          return af.apply(cm, value);
-        }
-
-        @Override
-        public double end() {
-          return 0;
-        }
-      };
-      x.walkInOptimizedOrder(visitor);
+      if (af instanceof SoftmaxActivationFunction) {
+        x = ((SoftmaxActivationFunction) af).applyMatrix(x);
+      } else {
+        x.walkInOptimizedOrder(new ActivationFunctionVisitor(af, cm));
+      }
       debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;
   }
 
+  private static class ActivationFunctionVisitor implements RealMatrixChangingVisitor {
+
+    private final ActivationFunction<Double> af;
+    private final RealMatrix matrix;
+
+    ActivationFunctionVisitor(ActivationFunction<Double> af, RealMatrix matrix) {
+      this.af = af;
+      this.matrix = matrix;
+    }
+
+    @Override
+    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+    }
+
+    @Override
+    public double visit(int row, int column, double value) {
+      return af.apply(matrix, value);
+    }
+
+    @Override
+    public double end() {
+      return 0;
+    }
+  }
+
 }
\ No newline at end of file
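
This refactoring matters for softmax in particular: the generic visitor invokes af.apply(matrix, value) once per matrix entry, and each call pays at least a cache lookup for the normalizing denominator, whereas applyMatrix computes the exp-sum once and normalizes the whole matrix in a single pass. A small sketch of the fast path, using only commons-math3 and the class from this commit:

    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.yay.core.SoftmaxActivationFunction;

    RealMatrix x = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
    RealMatrix y = new SoftmaxActivationFunction().applyMatrix(x);
    // y ~ [0.090, 0.245, 0.665]: each entry is exp(x_ij) divided by
    // the sum of exp over all entries, so the entries sum to 1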

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Thu Nov 12 13:39:18 2015
@@ -19,6 +19,9 @@
 package org.apache.yay.core;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.ActivationFunction;
 
 import java.util.Map;
@@ -31,6 +34,25 @@ public class SoftmaxActivationFunction i
 
   private static final Map<RealMatrix, Double> cache = new WeakHashMap<RealMatrix, Double>();
 
   @Override
   public Double apply(RealMatrix weights, Double signal) {
     double num = Math.exp(signal);
@@ -38,18 +60,49 @@ public class SoftmaxActivationFunction i
     return num / den;
   }
 
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    RealMatrix matrix = weights.copy();
+    final double den = expDen(matrix);
+    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        return Math.exp(value) / den;
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+    return matrix;
+  }
+
+  private double expDen(RealMatrix matrix) {
+    double d = 0d;
+    for (int i = 0; i < matrix.getRowDimension(); i++) {
+      RealVector currentRow = matrix.getRowVector(i);
+      for (int j = 0; j < matrix.getColumnDimension(); j++) {
+        double entry = currentRow.getEntry(j);
+        d += Math.exp(entry);
+      }
+    }
+    return d;
+  }
+
   private double getDen(RealMatrix weights) {
-    Double d = cache.get(weights);
-    if (d == null) {
-      double den = 0d;
-      for (int i = 0; i < weights.getRowDimension(); i++) {
-        double[] row1 = weights.getRow(i);
-        for (int j = 0; j < weights.getColumnDimension(); j++) {
-          den += Math.exp(row1[j]);
-        }
+    Double d;
+    // look up and populate the cache inside one synchronized block, so two
+    // threads cannot interleave between the get and the put
+    synchronized (cache) {
+      d = cache.get(weights);
+      if (d == null) {
+        d = expDen(weights);
+        cache.put(weights, d);
       }
-      d = den;
-      cache.put(weights, d);
     }
     }
     return d;
   }
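
The element-wise apply path is still exercised whenever softmax is mixed with other activation functions, and it stays consistent with applyMatrix because getDen memoizes the same exp-sum per matrix. A sketch of the expected agreement between the two paths:

    SoftmaxActivationFunction softmax = new SoftmaxActivationFunction();
    RealMatrix x = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
    RealMatrix y = softmax.applyMatrix(x);
    for (int j = 0; j < x.getColumnDimension(); j++) {
      // the first call computes and caches the denominator, later calls reuse it
      double elementWise = softmax.apply(x, x.getEntry(0, j));
      assert Math.abs(elementWise - y.getEntry(0, j)) < 1e-9;
    }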

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java Thu Nov 12 13:39:18 2015
@@ -18,8 +18,6 @@
  */
 package org.apache.yay.core.utils;
 
-import java.util.ArrayList;
-import java.util.Collection;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.OpenMapRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
@@ -27,11 +25,21 @@ import org.apache.commons.math3.linear.R
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.WeakHashMap;
+
 /**
  * Temporary class for conversion between model objects and commons-math matrices/vectors
  */
 public class ConversionUtils {
 
+  private static final WeakHashMap<String, Double[]> wordCache = new WeakHashMap<String, Double[]>();
+  private static final WeakHashMap<String, Integer> vocabularyCache = new WeakHashMap<String, Integer>();
+
   /**
    * Converts a set of examples to a matrix of inputs with features
    *
@@ -82,7 +90,7 @@ public class ConversionUtils {
    * <code>T</code> objects.
    *
    * @param featureVector the vector of features
-   * @param <T> the type of features
+   * @param <T>           the type of features
    * @return a vector of Doubles
    */
   public static <T> Collection<T> toValuesCollection(Collection<Feature<T>> featureVector) {
@@ -107,4 +115,41 @@ public class ConversionUtils {
     }
     return doubles;
   }
+
+  public static Double[] hotEncode(byte[] word, List<String> vocabulary) {
+    String wordString = new String(word);
+    Double[] vector = wordCache.get(wordString);
+    if (vector == null) {
+      vector = new Double[vocabulary.size()];
+      Integer index = vocabularyCache.get(wordString);
+      if (index == null) {
+        index = Collections.binarySearch(vocabulary, wordString);
+        vocabularyCache.put(wordString, index);
+      }
+      if (index < 0) {
+        // fail fast with a clear error: a missing word would otherwise surface
+        // as an ArrayIndexOutOfBoundsException a few lines below
+        throw new IllegalArgumentException("word '" + wordString + "' is not in the vocabulary");
+      }
+      Arrays.fill(vector, 0d);
+      vector[index] = 1d;
+      wordCache.put(wordString, vector);
+    }
+    return vector;
+  }
+
+  public static Double[] hotEncode(int index, int size) {
+    Double[] vector = new Double[size];
+    Arrays.fill(vector, 0d);
+    vector[index] = 1d;
+    return vector;
+  }
+
+  public static String hotDecode(Double[] doubles, List<String> vocabulary) {
+    double max = -Double.MAX_VALUE;
+    int index = -1;
+    for (int i = 0; i < doubles.length; i++) {
+      Double aDouble = doubles[i];
+      if (aDouble > max) {
+        max = aDouble;
+        index = i;
+      }
+    }
+    return vocabulary.get(index);
+  }
 }
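
The new helpers are meant to round-trip, with the caveat that hotEncode relies on Collections.binarySearch and therefore only works on a sorted vocabulary. A small sketch with a made-up four-word vocabulary:

    List<String> vocabulary = Arrays.asList("brown", "fox", "quick", "the"); // sorted
    Double[] oneHot = ConversionUtils.hotEncode("fox".getBytes(), vocabulary);
    // oneHot = [0.0, 1.0, 0.0, 0.0]
    String word = ConversionUtils.hotDecode(oneHot, vocabulary); // "fox"
    // the index-based overload produces the same vector without the lookup
    Double[] sameVector = ConversionUtils.hotEncode(1, vocabulary.size());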

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Thu Nov 12 13:39:18 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core;
 
+import com.google.common.base.Splitter;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.ml.distance.CanberraDistance;
@@ -41,15 +42,27 @@ import java.io.FileWriter;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
 import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
 
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -59,22 +72,32 @@ import static org.junit.Assert.assertNot
  */
 public class WordVectorsTest {
 
+  private static final boolean measure = false;
+
+  private static final boolean serialize = true;
+
   @Test
   public void testSGM() throws Exception {
-    Collection<String> sentences = getSentences();
-    assertFalse(sentences.isEmpty());
-    List<String> vocabulary = getVocabulary(sentences);
-    assertFalse(vocabulary.isEmpty());
-    Collections.sort(vocabulary);
-    Collection<String> fragments = getFragments(sentences, 4);
+
+    Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
+
+    System.out.println("reading fragments");
+    int window = 4;
+    Queue<List<byte[]>> fragments = getFragments(path, window);
     assertFalse(fragments.isEmpty());
+    System.out.println("generating vocabulary");
+    List<String> vocabulary = getVocabulary(path);
+    assertFalse(vocabulary.isEmpty());
 
-    TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments);
+    System.out.println("creating training set");
+    TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments, window);
+    fragments.clear();
     TrainingExample<Double, Double> next = trainingSet.iterator().next();
 
     int inputSize = next.getFeatures().size();
     int outputSize = next.getOutput().length;
-    int hiddenSize = 100;
+    int hiddenSize = 30;
+    System.out.println("initializing neural network");
     RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
 
     Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
@@ -83,137 +106,80 @@ public class WordVectorsTest {
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
     BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.01d, 1,
             BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(),
-            100);
+            trainingSet.size());
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
 
+    System.out.println("learning...");
     RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet);
 
+    System.out.println("learning finished");
     RealMatrix wordVectors = learnedWeights[0];
 
     assertNotNull(wordVectors);
 
-    Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
-    measures.add(new EuclideanDistance());
-    measures.add(new CanberraDistance());
-    measures.add(new ChebyshevDistance());
-    measures.add(new ManhattanDistance());
-    measures.add(new EarthMoversDistance());
-    measures.add(new DistanceMeasure() {
-      private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
-
-      @Override
-      public double compute(double[] a, double[] b) {
-        return 1 / pearsonsCorrelation.correlation(a, b);
+    if (serialize) {
+      System.out.println("serializing word vectors");
+      BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
+      for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+        double[] a = wordVectors.getColumnVector(i).toArray();
+        String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
+        csq = csq.substring(1, csq.length() - 1);
+        bufferedWriter.append(csq);
+        bufferedWriter.append(",");
+        bufferedWriter.append(vocabulary.get(i - 1));
+        bufferedWriter.newLine();
       }
+      bufferedWriter.flush();
+      bufferedWriter.close();
+    }
 
-      @Override
-      public String toString() {
-        return "inverse pearson correlation distance measure";
-      }
-    });
-    measures.add(new DistanceMeasure() {
-      @Override
-      public double compute(double[] a, double[] b) {
-        double dp = 0.0;
-        double na = 0.0;
-        double nb = 0.0;
-        for (int i = 0; i < a.length; i++) {
-          dp += a[i] * b[i];
-          na += Math.pow(a[i], 2);
-          nb += Math.pow(b[i], 2);
+    if (measure) {
+      System.out.println("measuring similarities");
+      Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
+      measures.add(new EuclideanDistance());
+      measures.add(new CanberraDistance());
+      measures.add(new ChebyshevDistance());
+      measures.add(new ManhattanDistance());
+      measures.add(new EarthMoversDistance());
+      measures.add(new DistanceMeasure() {
+        private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
+
+        @Override
+        public double compute(double[] a, double[] b) {
+          return 1 / pearsonsCorrelation.correlation(a, b);
         }
-        double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
-        return 1 / cosineSimilarity;
-      }
 
-      @Override
-      public String toString() {
-        return "inverse cosine similarity distance measure";
-      }
-    });
+        @Override
+        public String toString() {
+          return "inverse pearson correlation distance measure";
+        }
+      });
+      measures.add(new DistanceMeasure() {
+        @Override
+        public double compute(double[] a, double[] b) {
+          double dp = 0.0;
+          double na = 0.0;
+          double nb = 0.0;
+          for (int i = 0; i < a.length; i++) {
+            dp += a[i] * b[i];
+            na += Math.pow(a[i], 2);
+            nb += Math.pow(b[i], 2);
+          }
+          double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
+          return 1 / cosineSimilarity;
+        }
 
-    for (DistanceMeasure distanceMeasure : measures) {
-      System.out.println("computing similarity using " + distanceMeasure);
-      computeSimilarities(vocabulary, wordVectors, distanceMeasure);
-    }
+        @Override
+        public String toString() {
+          return "inverse cosine similarity distance measure";
+        }
+      });
 
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
-    for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
-      double[] a = wordVectors.getColumnVector(i).toArray();
-      String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
-      csq = csq.substring(1, csq.length() - 1);
-      bufferedWriter.append(csq);
-      bufferedWriter.append(",");
-      bufferedWriter.append(vocabulary.get(i-1));
-      bufferedWriter.newLine();
+      for (DistanceMeasure distanceMeasure : measures) {
+        System.out.println("computing similarity using " + distanceMeasure);
+        computeSimilarities(vocabulary, wordVectors, distanceMeasure);
+      }
     }
-
-    bufferedWriter.flush();
-    bufferedWriter.close();
-
-//    RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
-//
-//    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
-//    int m = 0;
-//    for (String word : vocabulary) {
-//      final Double[] doubles = hotEncode(word, vocabulary);
-//      Input<Double> input = new TrainingExample<Double, Double>() {
-//        @Override
-//        public ArrayList<Feature<Double>> getFeatures() {
-//          ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-//          Feature<Double> byasFeature = new Feature<Double>();
-//          byasFeature.setValue(1d);
-//          features.add(byasFeature);
-//          for (Double d : doubles) {
-//            Feature<Double> f = new Feature<Double>();
-//            f.setValue(d);
-//            features.add(f);
-//          }
-//          return features;
-//        }
-//
-//        @Override
-//        public Double[] getOutput() {
-//          return new Double[0];
-//        }
-//      };
-//      Double[] predict = neuralNetwork.predict(input);
-//      assertNotNull(predict);
-//      double[] row = new double[predict.length];
-//      for (int x = 0; x < row.length; x++) {
-//        row[x] = predict[x];
-//      }
-//      mappingsMatrix.setRow(m, row);
-//      m++;
-//
-//      String vectorString = Arrays.toString(predict);
-//      bufferedWriter.append(vectorString);
-//      bufferedWriter.newLine();
-//
-//      Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
-//      assertNotNull(wordVec1);
-//      Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size());
-//      assertNotNull(wordVec2);
-//      Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size());
-//      assertNotNull(wordVec3);
-//
-//      String word1 = hotDecode(wordVec1, vocabulary);
-//      assertNotNull(word1);
-//      assertTrue(vocabulary.contains(word1));
-//      String word2 = hotDecode(wordVec2, vocabulary);
-//      assertNotNull(word2);
-//      assertTrue(vocabulary.contains(word2));
-//      String word3 = hotDecode(wordVec3, vocabulary);
-//      assertNotNull(word3);
-//      assertTrue(vocabulary.contains(word3));
-//
-//      System.out.println(word + " generates " + word1 + " " + word2 + " " + word3);
-//    }
-//    bufferedWriter.flush();
-//    bufferedWriter.close();
-//
-//    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
-//    MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
   }
 
   private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) {
@@ -272,28 +238,17 @@ public class WordVectorsTest {
     }
   }
 
-  private String hotDecode(Double[] doubles, List<String> vocabulary) {
-    double max = -Double.MAX_VALUE;
-    int index = -1;
-    for (int i = 0; i < doubles.length; i++) {
-      Double aDouble = doubles[i];
-      if (aDouble > max) {
-        max = aDouble;
-        index = i;
-      }
-    }
-    return vocabulary.get(index);
-  }
-
-  private TrainingSet<Double, Double> createTrainingSet(List<String> vocabulary, Collection<String> fragments) {
-    Collection<TrainingExample<Double, Double>> samples = new LinkedList<TrainingExample<Double, Double>>();
-    for (String fragment : fragments) {
-      String[] tokens = fragment.split(" ");
-      String inputWord = null;
-      for (int i = 0; i < tokens.length; i++) {
-        List<String> outputWords = new LinkedList<String>();
-        for (int j = 0; j < tokens.length; j++) {
-          String token = tokens[i];
+  private TrainingSet<Double, Double> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) {
+    long start = System.currentTimeMillis();
+    Collection<TrainingExample<Double, Double>> samples = new LinkedList<>();
+    List<byte[]> fragment;
+    while ((fragment = fragments.poll()) != null) {
+      byte[] inputWord = null;
+      for (int i = 0; i < fragment.size(); i++) {
+        List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
+        for (int j = 0; j < fragment.size(); j++) {
+          byte[] token = fragment.get(j);
           if (i == j) {
             inputWord = token;
           } else {
@@ -301,92 +256,152 @@ public class WordVectorsTest {
           }
         }
 
-        final Double[] input = hotEncode(inputWord, vocabulary);
-        final Double[] outputs = new Double[outputWords.size() * vocabulary.size()];
-        for (int k = 0; k < outputWords.size(); k++) {
-          Double[] doubles = hotEncode(outputWords.get(k), vocabulary);
-          for (int z = 0; z < doubles.length; z++) {
-            outputs[(k * doubles.length) + z] = doubles[z];
-          }
-        }
+        final byte[] finalInputWord = inputWord;
         samples.add(new TrainingExample<Double, Double>() {
           @Override
           public Double[] getOutput() {
-            return outputs;
+            Double[] doubles = new Double[window - 1];
+            for (int i = 0; i < doubles.length; i++) {
+              doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i)));
+            }
+            return doubles;
           }
 
           @Override
           public ArrayList<Feature<Double>> getFeatures() {
-            ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-            Feature<Double> byasFeature = new Feature<Double>();
-            byasFeature.setValue(1d);
-            features.add(byasFeature);
-            for (Double d : input) {
-              Feature<Double> e = new Feature<Double>();
-              e.setValue(d);
-              features.add(e);
-            }
+            ArrayList<Feature<Double>> features = new ArrayList<>();
+            Feature<Double> e = new Feature<>();
+            e.setValue((double) vocabulary.indexOf(new String(finalInputWord)));
+            features.add(e);
             return features;
           }
         });
       }
     }
-    return new TrainingSet<Double, Double>(samples);
+    EncodedTrainingSet trainingSet = new EncodedTrainingSet(samples, vocabulary, window);
+
+    long end = System.currentTimeMillis();
+    System.out.println("training set created in " + (end - start) / 60000 + " minutes");
+
+    return trainingSet;
   }
 
-  private Double[] hotEncode(String word, List<String> vocabulary) {
-    Double[] vector = new Double[vocabulary.size()];
-    int index = Collections.binarySearch(vocabulary, word);
-    Arrays.fill(vector, 0d);
-    vector[index] = 1d;
-    return vector;
+
+  private List<String> getVocabulary(Path path) throws IOException {
+    long start = System.currentTimeMillis();
+    Set<String> vocabulary = new HashSet<String>();
+    SeekableByteChannel sbc = Files.newByteChannel(path);
+    ByteBuffer buf = ByteBuffer.allocate(100);
+    try {
+
+      String encoding = System.getProperty("file.encoding");
+      StringBuilder previous = new StringBuilder();
+      Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+      while (sbc.read(buf) > 0) {
+        buf.flip(); // limit = bytes actually read; rewind() would decode stale bytes too
+        CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+        String string = charBuffer.toString();
+        List<String> split = splitter.splitToList(string);
+        int splitSize = split.size();
+        if (splitSize > 1) {
+          String term = previous.append(split.get(0)).toString();
+          vocabulary.add(term.intern());
+          for (int i = 1; i < splitSize - 1; i++) {
+            String term2 = split.get(i);
+            vocabulary.add(term2.intern());
+          }
+          previous = new StringBuilder().append(split.get(splitSize - 1));
+        } else if (split.size() == 1) {
+          previous.append(string);
+        }
+        buf.clear();
+      }
+    } catch (IOException x) {
+      System.err.println("caught exception: " + x);
+    } finally {
+      sbc.close();
+      buf.clear();
+    }
+    long end = System.currentTimeMillis();
+    List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()]));
+    Collections.sort(list);
+    System.out.println("vocabulary read in " + (end - start) / 60000 + " minutes (" + (list.size()) + ")");
+    return list;
   }
 
-  private List<String> getVocabulary(Collection<String> sentences) {
+  private List<String> getVocabulary(Collection<byte[]> sentences) {
+    long start = System.currentTimeMillis();
     List<String> vocabulary = new LinkedList<String>();
-    for (String sentence : sentences) {
-      for (String token : sentence.split(" ")) {
+    for (byte[] sentence : sentences) {
+      for (String token : new String(sentence).split(" ")) {
         if (!vocabulary.contains(token)) {
           vocabulary.add(token);
         }
       }
     }
+    System.out.println("sorting vocabulary");
     Collections.sort(vocabulary);
+    long end = System.currentTimeMillis();
+    System.out.println("vocabulary generated in " + (end - start) / 60000 + " minutes");
     return vocabulary;
   }
 
-  private Collection<String> getFragments(Collection<String> vocabulary, int w) {
-    Collection<String> fragments = new LinkedList<String>();
-    for (String sentence : vocabulary) {
-      while (sentence.length() > 0) {
-        int idx = 0;
-        for (int i = 0; i < w; i++) {
-          idx = sentence.indexOf(' ', idx + 1);
-        }
-        if (idx > 0) {
-          String fragment = sentence.substring(0, idx);
-          if (fragment.split(" ").length == 4) {
+  private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException {
+    long start = System.currentTimeMillis();
+    Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<List<byte[]>>();
+
+    SeekableByteChannel sbc = Files.newByteChannel(path);
+    ByteBuffer buf = ByteBuffer.allocate(100);
+    try {
+
+      String encoding = System.getProperty("file.encoding");
+      StringBuilder previous = new StringBuilder();
+      Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+      int lastConsumedIndex = -1;
+      while (sbc.read(buf) > 0) {
+        buf.flip(); // limit = bytes actually read; rewind() would decode stale bytes too
+        CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+        String string = charBuffer.toString();
+        List<String> split = splitter.splitToList(string);
+        int splitSize = split.size();
+        if (splitSize > w) {
+          for (int j = 0; j < splitSize - w; j++) {
+            List<byte[]> fragment = new ArrayList<byte[]>(w);
+            fragment.add(previous.append(split.get(j)).toString().getBytes());
+            for (int i = 1; i < w; i++) {
+              fragment.add(split.get(i + j).getBytes());
+            }
+            // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
+            lastConsumedIndex = j + w;
             fragments.add(fragment);
-            sentence = sentence.substring(sentence.indexOf(' ') + 1);
-          }
-        } else {
-          if (sentence.split(" ").length == 4) {
-            fragments.add(sentence);
-            sentence = "";
+            previous = new StringBuilder();
           }
+          previous = new StringBuilder().append(split.get(splitSize - 1));
+        } else if (split.size() == w) {
+          previous.append(string);
         }
+        buf.clear();
       }
+    } catch (IOException x) {
+      System.err.println("caught exception: " + x);
+    } finally {
+      sbc.close();
+      buf.clear();
     }
+    long end = System.currentTimeMillis();
+    System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
     return fragments;
   }
 
   private Collection<String> getSentences() throws IOException {
+    Collection<String> sentences = new LinkedList<String>();
+
     InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt");
     BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream));
-    Collection<String> sentences = new LinkedList<String>();
     String line;
     while ((line = bufferedReader.readLine()) != null) {
-      sentences.add(line.toLowerCase());
+      String cleanLine = line.toLowerCase().replaceAll("[.;,:]", "");
+      sentences.add(cleanLine);
     }
     return sentences;
   }
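
To make the new sample layout concrete: for the fragment "the quick brown fox" with window = 4, position i = 1 yields one sample whose single feature is the vocabulary index of "quick" and whose output holds the indices of "the", "brown" and "fox" (window - 1 = 3 values). EncodedTrainingSet later expands those indices into a one-hot input vector of length |V| and three concatenated one-hot output vectors of total length 3 * |V|, so the expensive dense encoding now happens lazily at iteration time instead of up front; that is where the training set creation speedup in this commit comes from.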

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Thu Nov 12 13:39:18 2015
@@ -152,8 +152,8 @@
         <artifactId>maven-compiler-plugin</artifactId>
         <version>2.0.2</version>
         <configuration>
-          <source>1.6</source>
-          <target>1.6</target>
+          <source>1.8</source>
+          <target>1.8</target>
           <encoding>UTF-8</encoding>
         </configuration>
       </plugin>
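
The language level bump is required by the new test code, which uses the Java 7 diamond operator (e.g. new LinkedList<>()) and the java.nio.file API (Files, Paths, SeekableByteChannel); building for 1.8 covers both.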


