You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/11/12 14:39:19 UTC
svn commit: r1714045 - in /labs/yay/trunk: ./ core/
core/src/main/java/org/apache/yay/core/
core/src/main/java/org/apache/yay/core/utils/
core/src/test/java/org/apache/yay/core/
Author: tommaso
Date: Thu Nov 12 13:39:18 2015
New Revision: 1714045
URL: http://svn.apache.org/viewvc?rev=1714045&view=rev
Log:
performance improvements on training set creation and softmax
Added:
labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java (with props)
Modified:
labs/yay/trunk/core/pom.xml
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
labs/yay/trunk/pom.xml
Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Thu Nov 12 13:39:18 2015
@@ -52,5 +52,23 @@
<artifactId>commons-collections</artifactId>
<version>3.2.1</version>
</dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <version>18.0</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>2.18.1</version>
+ <configuration>
+ <argLine>-Xmx8g</argLine>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
</project>
\ No newline at end of file
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Thu Nov 12 13:39:18 2015
@@ -163,7 +163,7 @@ public class BackPropagationLearningStra
}
};
realMatrix.walkInOptimizedOrder(visitor);
- if (updatedParameters[l]== null) {
+ if (updatedParameters[l] == null) {
updatedParameters[l] = realMatrix;
}
}
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java?rev=1714045&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java Thu Nov 12 13:39:18 2015
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.yay.Feature;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+import org.apache.yay.core.utils.ExamplesFactory;
+
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * A one-hot encoded {@link TrainingSet}: only vocabulary indices are stored for each example, and the
+ * one-hot input/output vectors are materialized lazily, one example at a time, while iterating.
+ */
+public class EncodedTrainingSet extends TrainingSet<Double, Double> {
+  private final List<String> vocabulary;
+  private final int window;
+
+  /**
+   * @param samples    examples whose first feature value is the index of the input word and whose
+   *                   outputs are the indices of the context words
+   * @param vocabulary the vocabulary the stored indices refer to; its size is the one-hot vector length
+   * @param window     the fragment window size; each example is expected to carry {@code window - 1} outputs
+   */
+  public EncodedTrainingSet(Collection<TrainingExample<Double, Double>> samples, List<String> vocabulary, int window) {
+    super(samples);
+    this.vocabulary = vocabulary;
+    this.window = window;
+  }
+
+  /**
+   * Returns an iterator that decodes every stored example into one-hot form: the input is the one-hot
+   * vector of the first feature's index, the output is the concatenation of the one-hot vectors of
+   * every output index.
+   */
+  @Override
+  public Iterator<TrainingExample<Double, Double>> iterator() {
+    // The delegate must be obtained exactly once: calling super.iterator() inside hasNext()/next()
+    // would restart from the first sample on every call and never terminate.
+    final Iterator<TrainingExample<Double, Double>> delegate = super.iterator();
+    final int vocabularySize = vocabulary.size();
+    return new Iterator<TrainingExample<Double, Double>>() {
+      @Override
+      public boolean hasNext() {
+        return delegate.hasNext();
+      }
+
+      @Override
+      public TrainingExample<Double, Double> next() {
+        TrainingExample<Double, Double> sample = delegate.next();
+        Double[] inputs = encodeInput(sample.getFeatures(), vocabularySize);
+        Double[] outputs = new Double[vocabularySize * (window - 1)];
+        int offset = 0;
+        for (Double index : sample.getOutput()) {
+          Double[] encoded = ConversionUtils.hotEncode(index.intValue(), vocabularySize);
+          System.arraycopy(encoded, 0, outputs, offset, encoded.length);
+          offset += vocabularySize;
+        }
+        return ExamplesFactory.createDoubleArrayTrainingExample(outputs, inputs);
+      }
+    };
+  }
+
+  /**
+   * One-hot encodes the value of the first feature; when there are no features an array of
+   * {@code vocabularySize} nulls is returned (same as before this change).
+   */
+  private static Double[] encodeInput(Collection<Feature<Double>> features, int vocabularySize) {
+    Iterator<Feature<Double>> it = features.iterator();
+    if (!it.hasNext()) {
+      return new Double[vocabularySize];
+    }
+    return ConversionUtils.hotEncode(it.next().getValue().intValue(), vocabularySize);
+  }
+}
Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java
------------------------------------------------------------------------------
svn:eol-style = native
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Thu Nov 12 13:39:18 2015
@@ -79,27 +79,43 @@ public class FeedForwardStrategy impleme
// apply the activation function to each element in the matrix
int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
final ActivationFunction<Double> af = activationFunctionMap.get(idx);
- RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {
- @Override
- public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
- }
-
- @Override
- public double visit(int row, int column, double value) {
- return af.apply(cm, value);
- }
-
- @Override
- public double end() {
- return 0;
- }
- };
- x.walkInOptimizedOrder(visitor);
+ if (af instanceof SoftmaxActivationFunction) {
+ x = ((SoftmaxActivationFunction) af).applyMatrix(x);
+ } else {
+ x.walkInOptimizedOrder(new ActivationFunctionVisitor(af, cm));
+ }
debugOutput[w] = x.getRowVector(0);
}
return debugOutput;
}
+  /**
+   * Matrix visitor that replaces every entry with the result of an {@link ActivationFunction}.
+   * The function is invoked as {@code function.apply(reference, entry)} for each visited entry.
+   */
+  private static class ActivationFunctionVisitor implements RealMatrixChangingVisitor {
+
+    private final ActivationFunction<Double> function;
+    private final RealMatrix reference;
+
+    ActivationFunctionVisitor(ActivationFunction<Double> function, RealMatrix reference) {
+      this.reference = reference;
+      this.function = function;
+    }
+
+    @Override
+    public double end() {
+      return 0;
+    }
+
+    @Override
+    public double visit(int row, int column, double value) {
+      final Double activated = function.apply(reference, value);
+      return activated;
+    }
+
+    @Override
+    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+      // stateless visitor: nothing to prepare before the walk
+    }
+  }
+
}
\ No newline at end of file
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Thu Nov 12 13:39:18 2015
@@ -19,6 +19,9 @@
package org.apache.yay.core;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.commons.math3.stat.descriptive.rank.Max;
import org.apache.yay.ActivationFunction;
import java.util.Map;
@@ -31,6 +34,25 @@ public class SoftmaxActivationFunction i
private static final Map<RealMatrix, Double> cache = new WeakHashMap<RealMatrix, Double>();
+  /** NOTE(review): not referenced by any code visible in this change. */
+  private static final Max m = new Max();
+
+  /**
+   * Visitor replacing every visited entry {@code x} with {@code exp(x)}.
+   * NOTE(review): not referenced by any code visible in this change.
+   */
+  private static final RealMatrixChangingVisitor expVisitor = new RealMatrixChangingVisitor() {
+    @Override
+    public double visit(int row, int column, double value) {
+      return Math.exp(value);
+    }
+
+    @Override
+    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+      // stateless: nothing to prepare
+    }
+
+    @Override
+    public double end() {
+      return 0;
+    }
+  };
+
@Override
public Double apply(RealMatrix weights, Double signal) {
double num = Math.exp(signal);
@@ -38,18 +60,49 @@ public class SoftmaxActivationFunction i
return num / den;
}
+  /**
+   * Applies softmax to the whole matrix at once: every entry {@code x} becomes
+   * {@code exp(x) / sum(exp(y))}, where the sum runs over all entries of the matrix.
+   * The input matrix is not modified; a transformed copy is returned.
+   * <p>
+   * The largest finite entry is subtracted before exponentiating. This leaves the result
+   * mathematically unchanged but keeps {@code exp} from overflowing to infinity (which would turn
+   * every entry into NaN) for large activations.
+   *
+   * @param weights the matrix to transform
+   * @return a new matrix holding the softmax of {@code weights}
+   */
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    RealMatrix matrix = weights.copy();
+    double max = maxEntry(matrix);
+    // an infinite max cannot be subtracted (inf - inf = NaN); fall back to the unshifted formula
+    final double shift = Double.isInfinite(max) ? 0d : max;
+    // first pass: exponentiate in place and accumulate the denominator, which end() hands back
+    final double den = matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      private double sum;
+
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+        sum = 0d;
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        double e = Math.exp(value - shift);
+        sum += e;
+        return e;
+      }
+
+      @Override
+      public double end() {
+        return sum;
+      }
+    });
+    // second pass: normalize the exponentials so they sum to one
+    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        return value / den;
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+    return matrix;
+  }
+
+  /** Returns the largest entry of the matrix (NaN if any entry is NaN, as per {@link Math#max}). */
+  private static double maxEntry(RealMatrix matrix) {
+    double max = Double.NEGATIVE_INFINITY;
+    for (int i = 0; i < matrix.getRowDimension(); i++) {
+      for (int j = 0; j < matrix.getColumnDimension(); j++) {
+        max = Math.max(max, matrix.getEntry(i, j));
+      }
+    }
+    return max;
+  }
+
+ private double expDen(RealMatrix matrix) {
+ double d = 0d;
+ for (int i = 0; i < matrix.getRowDimension(); i++) {
+ RealVector currentRow = matrix.getRowVector(i);
+ for (int j = 0; j < matrix.getColumnDimension(); j++) {
+ double entry = currentRow.getEntry(j);
+ d += Math.exp(entry);
+ }
+ }
+ return d;
+ }
+
private double getDen(RealMatrix weights) {
Double d = cache.get(weights);
- if (d == null) {
- double den = 0d;
- for (int i = 0; i < weights.getRowDimension(); i++) {
- double[] row1 = weights.getRow(i);
- for (int j = 0; j < weights.getColumnDimension(); j++) {
- den += Math.exp(row1[j]);
- }
+ synchronized (cache) {
+ if (d == null) {
+ d = expDen(weights.copy());
+ cache.put(weights, d);
}
- d = den;
- cache.put(weights, d);
}
return d;
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java Thu Nov 12 13:39:18 2015
@@ -18,8 +18,6 @@
*/
package org.apache.yay.core.utils;
-import java.util.ArrayList;
-import java.util.Collection;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealMatrix;
@@ -27,11 +25,21 @@ import org.apache.commons.math3.linear.R
import org.apache.yay.Feature;
import org.apache.yay.Input;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.WeakHashMap;
+
/**
* Temporary class for conversion between model objects and commons-math matrices/vectors
*/
public class ConversionUtils {
+ private static final WeakHashMap<String, Double[]> wordCache = new WeakHashMap<String, Double[]>();
+ private static final WeakHashMap<String, Integer> vocabularyCache = new WeakHashMap<String, Integer>();
+
/**
* Converts a set of examples to a matrix of inputs with features
*
@@ -82,7 +90,7 @@ public class ConversionUtils {
* <code>T</code> objects.
*
* @param featureVector the vector of features
- * @param <T> the type of features
+ * @param <T> the type of features
* @return a vector of Doubles
*/
public static <T> Collection<T> toValuesCollection(Collection<Feature<T>> featureVector) {
@@ -107,4 +115,41 @@ public class ConversionUtils {
}
return doubles;
}
+
+  /**
+   * One-hot encodes a word against a vocabulary.
+   * <p>
+   * The word's index is located with {@link Collections#binarySearch}, so {@code vocabulary} must be
+   * sorted in natural order. Index and vector are memoized in the static caches keyed by the word only;
+   * cached entries are therefore re-validated against the given vocabulary so that a call with a
+   * different vocabulary never receives a stale vector.
+   * <p>
+   * The returned array may be shared with the cache and with other callers: do not modify it.
+   *
+   * @param word       the word's bytes, decoded with the platform default charset
+   * @param vocabulary the sorted vocabulary
+   * @return a vector of {@code vocabulary.size()} entries, {@code 1.0} at the word's index, {@code 0.0} elsewhere
+   * @throws IllegalArgumentException if the word is not in the vocabulary
+   */
+  public static Double[] hotEncode(byte[] word, List<String> vocabulary) {
+    String wordString = new String(word);
+    int size = vocabulary.size();
+    Integer index = vocabularyCache.get(wordString);
+    if (index == null || index < 0 || index >= size || !wordString.equals(vocabulary.get(index))) {
+      index = Collections.binarySearch(vocabulary, wordString);
+      if (index < 0) {
+        // binarySearch returns (-(insertion point) - 1) for missing keys; using it as an array index
+        // would only surface as an unhelpful ArrayIndexOutOfBoundsException
+        throw new IllegalArgumentException("word '" + wordString + "' is not in the vocabulary");
+      }
+      vocabularyCache.put(wordString, index);
+    }
+    Double[] vector = wordCache.get(wordString);
+    // a cached vector is reusable only if it has this vocabulary's length and is hot at this index
+    if (vector == null || vector.length != size || !Double.valueOf(1d).equals(vector[index])) {
+      vector = new Double[size];
+      Arrays.fill(vector, 0d);
+      vector[index] = 1d;
+      wordCache.put(wordString, vector);
+    }
+    return vector;
+  }
+
+  /**
+   * Builds a freshly allocated one-hot vector of length {@code size} with {@code 1.0} at
+   * {@code index} and {@code 0.0} everywhere else.
+   *
+   * @param index position of the hot entry; must be within {@code [0, size)}
+   * @param size  length of the vector
+   * @return the one-hot vector
+   */
+  public static Double[] hotEncode(int index, int size) {
+    final Double zero = 0d;
+    Double[] oneHot = new Double[size];
+    for (int i = 0; i < oneHot.length; i++) {
+      oneHot[i] = zero;
+    }
+    oneHot[index] = 1d;
+    return oneHot;
+  }
+
+  /**
+   * Decodes a (soft) one-hot vector back to a word by picking the entry with the largest value.
+   * Ties resolve to the lowest index; NaN entries are never selected.
+   *
+   * @param doubles    the vector to decode; entries must be non-null
+   * @param vocabulary the vocabulary, positionally aligned with {@code doubles}
+   * @return the vocabulary word at the position of the maximum entry
+   * @throws IllegalArgumentException if {@code doubles} is empty or contains only NaN values
+   */
+  public static String hotDecode(Double[] doubles, List<String> vocabulary) {
+    int best = -1;
+    double max = Double.NEGATIVE_INFINITY;
+    for (int i = 0; i < doubles.length; i++) {
+      double value = doubles[i];
+      if (Double.isNaN(value)) {
+        continue;
+      }
+      // the first comparable value always seeds the maximum, so even -Infinity entries can win
+      if (best < 0 || value > max) {
+        max = value;
+        best = i;
+      }
+    }
+    if (best < 0) {
+      throw new IllegalArgumentException("cannot decode a vector without any non-NaN value");
+    }
+    return vocabulary.get(best);
+  }
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Thu Nov 12 13:39:18 2015
@@ -18,6 +18,7 @@
*/
package org.apache.yay.core;
+import com.google.common.base.Splitter;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.ml.distance.CanberraDistance;
@@ -41,15 +42,27 @@ import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Queue;
import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
@@ -59,22 +72,32 @@ import static org.junit.Assert.assertNot
*/
public class WordVectorsTest {
+ private static final boolean measure = false;
+
+ private static final boolean serialize = true;
+
@Test
public void testSGM() throws Exception {
- Collection<String> sentences = getSentences();
- assertFalse(sentences.isEmpty());
- List<String> vocabulary = getVocabulary(sentences);
- assertFalse(vocabulary.isEmpty());
- Collections.sort(vocabulary);
- Collection<String> fragments = getFragments(sentences, 4);
+
+ Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
+
+ System.out.println("reading fragments");
+ int window = 4;
+ Queue<List<byte[]>> fragments = getFragments(path, window);
assertFalse(fragments.isEmpty());
+ System.out.println("generating vocabulary");
+ List<String> vocabulary = getVocabulary(path);
+ assertFalse(vocabulary.isEmpty());
- TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments);
+ System.out.println("creating training set");
+ TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments, window);
+ fragments.clear();
TrainingExample<Double, Double> next = trainingSet.iterator().next();
int inputSize = next.getFeatures().size();
int outputSize = next.getOutput().length;
- int hiddenSize = 100;
+ int hiddenSize = 30;
+ System.out.println("initializing neural network");
RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
@@ -83,137 +106,80 @@ public class WordVectorsTest {
FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.01d, 1,
BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(),
- 100);
+ trainingSet.size());
NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
+ System.out.println("learning...");
RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet);
+ System.out.println("learning finished");
RealMatrix wordVectors = learnedWeights[0];
assertNotNull(wordVectors);
- Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
- measures.add(new EuclideanDistance());
- measures.add(new CanberraDistance());
- measures.add(new ChebyshevDistance());
- measures.add(new ManhattanDistance());
- measures.add(new EarthMoversDistance());
- measures.add(new DistanceMeasure() {
- private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
-
- @Override
- public double compute(double[] a, double[] b) {
- return 1 / pearsonsCorrelation.correlation(a, b);
+ if (serialize) {
+ System.out.println("serializing word vectors");
+ BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
+ for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+ double[] a = wordVectors.getColumnVector(i).toArray();
+ String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
+ csq = csq.substring(1, csq.length() - 1);
+ bufferedWriter.append(csq);
+ bufferedWriter.append(",");
+ bufferedWriter.append(vocabulary.get(i - 1));
+ bufferedWriter.newLine();
}
+ bufferedWriter.flush();
+ bufferedWriter.close();
+ }
- @Override
- public String toString() {
- return "inverse pearson correlation distance measure";
- }
- });
- measures.add(new DistanceMeasure() {
- @Override
- public double compute(double[] a, double[] b) {
- double dp = 0.0;
- double na = 0.0;
- double nb = 0.0;
- for (int i = 0; i < a.length; i++) {
- dp += a[i] * b[i];
- na += Math.pow(a[i], 2);
- nb += Math.pow(b[i], 2);
+ if (measure) {
+ System.out.println("measuring similarities");
+ Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
+ measures.add(new EuclideanDistance());
+ measures.add(new CanberraDistance());
+ measures.add(new ChebyshevDistance());
+ measures.add(new ManhattanDistance());
+ measures.add(new EarthMoversDistance());
+ measures.add(new DistanceMeasure() {
+ private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
+
+ @Override
+ public double compute(double[] a, double[] b) {
+ return 1 / pearsonsCorrelation.correlation(a, b);
}
- double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
- return 1 / cosineSimilarity;
- }
- @Override
- public String toString() {
- return "inverse cosine similarity distance measure";
- }
- });
+ @Override
+ public String toString() {
+ return "inverse pearson correlation distance measure";
+ }
+ });
+ measures.add(new DistanceMeasure() {
+ @Override
+ public double compute(double[] a, double[] b) {
+ double dp = 0.0;
+ double na = 0.0;
+ double nb = 0.0;
+ for (int i = 0; i < a.length; i++) {
+ dp += a[i] * b[i];
+ na += Math.pow(a[i], 2);
+ nb += Math.pow(b[i], 2);
+ }
+ double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
+ return 1 / cosineSimilarity;
+ }
- for (DistanceMeasure distanceMeasure : measures) {
- System.out.println("computing similarity using " + distanceMeasure);
- computeSimilarities(vocabulary, wordVectors, distanceMeasure);
- }
+ @Override
+ public String toString() {
+ return "inverse cosine similarity distance measure";
+ }
+ });
- BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
- for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
- double[] a = wordVectors.getColumnVector(i).toArray();
- String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
- csq = csq.substring(1, csq.length() - 1);
- bufferedWriter.append(csq);
- bufferedWriter.append(",");
- bufferedWriter.append(vocabulary.get(i-1));
- bufferedWriter.newLine();
+ for (DistanceMeasure distanceMeasure : measures) {
+ System.out.println("computing similarity using " + distanceMeasure);
+ computeSimilarities(vocabulary, wordVectors, distanceMeasure);
+ }
}
-
- bufferedWriter.flush();
- bufferedWriter.close();
-
-// RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
-//
-// BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
-// int m = 0;
-// for (String word : vocabulary) {
-// final Double[] doubles = hotEncode(word, vocabulary);
-// Input<Double> input = new TrainingExample<Double, Double>() {
-// @Override
-// public ArrayList<Feature<Double>> getFeatures() {
-// ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-// Feature<Double> byasFeature = new Feature<Double>();
-// byasFeature.setValue(1d);
-// features.add(byasFeature);
-// for (Double d : doubles) {
-// Feature<Double> f = new Feature<Double>();
-// f.setValue(d);
-// features.add(f);
-// }
-// return features;
-// }
-//
-// @Override
-// public Double[] getOutput() {
-// return new Double[0];
-// }
-// };
-// Double[] predict = neuralNetwork.predict(input);
-// assertNotNull(predict);
-// double[] row = new double[predict.length];
-// for (int x = 0; x < row.length; x++) {
-// row[x] = predict[x];
-// }
-// mappingsMatrix.setRow(m, row);
-// m++;
-//
-// String vectorString = Arrays.toString(predict);
-// bufferedWriter.append(vectorString);
-// bufferedWriter.newLine();
-//
-// Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
-// assertNotNull(wordVec1);
-// Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size());
-// assertNotNull(wordVec2);
-// Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size());
-// assertNotNull(wordVec3);
-//
-// String word1 = hotDecode(wordVec1, vocabulary);
-// assertNotNull(word1);
-// assertTrue(vocabulary.contains(word1));
-// String word2 = hotDecode(wordVec2, vocabulary);
-// assertNotNull(word2);
-// assertTrue(vocabulary.contains(word2));
-// String word3 = hotDecode(wordVec3, vocabulary);
-// assertNotNull(word3);
-// assertTrue(vocabulary.contains(word3));
-//
-// System.out.println(word + " generates " + word1 + " " + word2 + " " + word3);
-// }
-// bufferedWriter.flush();
-// bufferedWriter.close();
-//
-// ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
-// MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
}
private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) {
@@ -272,28 +238,17 @@ public class WordVectorsTest {
}
}
- private String hotDecode(Double[] doubles, List<String> vocabulary) {
- double max = -Double.MAX_VALUE;
- int index = -1;
- for (int i = 0; i < doubles.length; i++) {
- Double aDouble = doubles[i];
- if (aDouble > max) {
- max = aDouble;
- index = i;
- }
- }
- return vocabulary.get(index);
- }
-
- private TrainingSet<Double, Double> createTrainingSet(List<String> vocabulary, Collection<String> fragments) {
- Collection<TrainingExample<Double, Double>> samples = new LinkedList<TrainingExample<Double, Double>>();
- for (String fragment : fragments) {
- String[] tokens = fragment.split(" ");
- String inputWord = null;
- for (int i = 0; i < tokens.length; i++) {
- List<String> outputWords = new LinkedList<String>();
- for (int j = 0; j < tokens.length; j++) {
- String token = tokens[i];
+ private TrainingSet<Double, Double> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) {
+ long start = System.currentTimeMillis();
+ Path file = Paths.get("/Users/teofili/Desktop/ts.txt");
+ Collection<TrainingExample<Double, Double>> samples = new LinkedList<>();
+ List<byte[]> fragment;
+ while ((fragment = fragments.poll()) != null) {
+ byte[] inputWord = null;
+ for (int i = 0; i < fragment.size(); i++) {
+ List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
+ for (int j = 0; j < fragment.size(); j++) {
+ byte[] token = fragment.get(i);
if (i == j) {
inputWord = token;
} else {
@@ -301,92 +256,152 @@ public class WordVectorsTest {
}
}
- final Double[] input = hotEncode(inputWord, vocabulary);
- final Double[] outputs = new Double[outputWords.size() * vocabulary.size()];
- for (int k = 0; k < outputWords.size(); k++) {
- Double[] doubles = hotEncode(outputWords.get(k), vocabulary);
- for (int z = 0; z < doubles.length; z++) {
- outputs[(k * doubles.length) + z] = doubles[z];
- }
- }
+ final byte[] finalInputWord = inputWord;
samples.add(new TrainingExample<Double, Double>() {
@Override
public Double[] getOutput() {
- return outputs;
+ Double[] doubles = new Double[window - 1];
+ for (int i = 0; i < doubles.length; i++) {
+ doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i)));
+ }
+ return doubles;
}
@Override
public ArrayList<Feature<Double>> getFeatures() {
- ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
- Feature<Double> byasFeature = new Feature<Double>();
- byasFeature.setValue(1d);
- features.add(byasFeature);
- for (Double d : input) {
- Feature<Double> e = new Feature<Double>();
- e.setValue(d);
- features.add(e);
- }
+ ArrayList<Feature<Double>> features = new ArrayList<>();
+ Feature<Double> e = new Feature<>();
+ e.setValue((double) vocabulary.indexOf(new String(finalInputWord)));
+ features.add(e);
return features;
}
});
}
}
- return new TrainingSet<Double, Double>(samples);
+ EncodedTrainingSet trainingSet = new EncodedTrainingSet(samples, vocabulary, window);
+
+ long end = System.currentTimeMillis();
+ System.out.println("training set created in " + (end - start) / 60000 + " minutes");
+
+ return trainingSet;
}
- private Double[] hotEncode(String word, List<String> vocabulary) {
- Double[] vector = new Double[vocabulary.size()];
- int index = Collections.binarySearch(vocabulary, word);
- Arrays.fill(vector, 0d);
- vector[index] = 1d;
- return vector;
+
+  /**
+   * Reads the distinct whitespace-separated tokens of a text file and returns them sorted.
+   * <p>
+   * The file is decoded with the charset named by the {@code file.encoding} system property and
+   * read line by line through a reader, so tokens are never glued or split at buffer boundaries,
+   * multi-byte characters are never cut in half, and the last token of the file is not lost.
+   *
+   * @param path the text file to scan
+   * @return the sorted vocabulary
+   * @throws IOException if the file cannot be opened or read
+   */
+  private List<String> getVocabulary(Path path) throws IOException {
+    long start = System.currentTimeMillis();
+    Set<String> vocabulary = new HashSet<String>();
+    String encoding = System.getProperty("file.encoding");
+    Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+    // InputStreamReader replaces malformed input instead of failing, like Charset.decode did before
+    try (BufferedReader reader = new BufferedReader(
+        new InputStreamReader(Files.newInputStream(path), Charset.forName(encoding)))) {
+      String line;
+      while ((line = reader.readLine()) != null) {
+        for (String term : splitter.split(line)) {
+          vocabulary.add(term.intern());
+        }
+      }
+    }
+    List<String> list = new ArrayList<String>(vocabulary);
+    Collections.sort(list);
+    long end = System.currentTimeMillis();
+    System.out.println("vocabulary read in " + (end - start) / 60000 + " minutes (" + (list.size()) + ")");
+    return list;
+  }
}
- private List<String> getVocabulary(Collection<String> sentences) {
+ private List<String> getVocabulary(Collection<byte[]> sentences) {
+ long start = System.currentTimeMillis();
List<String> vocabulary = new LinkedList<String>();
- for (String sentence : sentences) {
- for (String token : sentence.split(" ")) {
+ for (byte[] sentence : sentences) {
+ for (String token : new String(sentence).split(" ")) {
if (!vocabulary.contains(token)) {
vocabulary.add(token);
}
}
}
+ System.out.println("sorting vocabulary");
Collections.sort(vocabulary);
+ long end = System.currentTimeMillis();
+ System.out.println("vocabulary generated in " + (end - start) / 60000 + " minutes");
return vocabulary;
}
- private Collection<String> getFragments(Collection<String> vocabulary, int w) {
- Collection<String> fragments = new LinkedList<String>();
- for (String sentence : vocabulary) {
- while (sentence.length() > 0) {
- int idx = 0;
- for (int i = 0; i < w; i++) {
- idx = sentence.indexOf(' ', idx + 1);
- }
- if (idx > 0) {
- String fragment = sentence.substring(0, idx);
- if (fragment.split(" ").length == 4) {
+ private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException {
+ long start = System.currentTimeMillis();
+ Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<List<byte[]>>();
+
+ SeekableByteChannel sbc = Files.newByteChannel(path);
+ ByteBuffer buf = ByteBuffer.allocate(100);
+ try {
+
+ String encoding = System.getProperty("file.encoding");
+ StringBuilder previous = new StringBuilder();
+ Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+ int lastConsumedIndex = -1;
+ while (sbc.read(buf) > 0) {
+ buf.rewind();
+ CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+ String string = charBuffer.toString();
+ List<String> split = splitter.splitToList(string);
+ int splitSize = split.size();
+ if (splitSize > w) {
+ for (int j = 0; j < splitSize - w; j++) {
+ List<byte[]> fragment = new ArrayList<byte[]>(w);
+ fragment.add(previous.append(split.get(j)).toString().getBytes());
+ for (int i = 1; i < w; i++) {
+ fragment.add(split.get(i + j).getBytes());
+ }
+ // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
+ lastConsumedIndex = j + w;
fragments.add(fragment);
- sentence = sentence.substring(sentence.indexOf(' ') + 1);
- }
- } else {
- if (sentence.split(" ").length == 4) {
- fragments.add(sentence);
- sentence = "";
+ previous = new StringBuilder();
}
+ previous = new StringBuilder().append(split.get(splitSize - 1));
+ } else if (split.size() == w) {
+ previous.append(string);
}
+ buf.flip();
}
+ } catch (IOException x) {
+ System.err.println("caught exception: " + x);
+ } finally {
+ sbc.close();
+ buf.clear();
}
+ long end = System.currentTimeMillis();
+ System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
return fragments;
}
private Collection<String> getSentences() throws IOException {
+ Collection<String> sentences = new LinkedList<String>();
+
InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt");
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream));
- Collection<String> sentences = new LinkedList<String>();
String line;
while ((line = bufferedReader.readLine()) != null) {
- sentences.add(line.toLowerCase());
+ String cleanLine = line.toLowerCase().replaceAll("\\.", "").replaceAll("\\;", "").replaceAll("\\,", "").replaceAll("\\:", "");
+ sentences.add(cleanLine);
}
return sentences;
}
Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Thu Nov 12 13:39:18 2015
@@ -152,8 +152,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>2.0.2</version>
<configuration>
- <source>1.6</source>
- <target>1.6</target>
+ <source>1.8</source>
+ <target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org