Posted to commits@labs.apache.org by to...@apache.org on 2016/02/18 11:13:42 UTC
svn commit: r1731036 - in /labs/yay/trunk/core: ./ src/main/java/org/apache/yay/ src/test/java/org/apache/yay/
Author: tommaso
Date: Thu Feb 18 10:13:42 2016
New Revision: 1731036
URL: http://svn.apache.org/viewvc?rev=1731036&view=rev
Log:
refactored SFFNN to MLN, added ReLU function and (compact) skip-gram
Added:
labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java
- copied, changed from r1724846, labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java (with props)
labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (with props)
labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
- copied, changed from r1724846, labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (with props)
Removed:
labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java
Modified:
labs/yay/trunk/core/pom.xml
Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1731036&r1=1731035&r2=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Thu Feb 18 10:13:42 2016
@@ -51,7 +51,6 @@
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>18.0</version>
- <scope>test</scope>
</dependency>
</dependencies>
<build>
Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java (from r1724846, labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java&r1=1724846&r2=1731036&rev=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java Thu Feb 18 10:13:42 2016
@@ -28,13 +28,11 @@ import java.util.Arrays;
import java.util.Random;
/**
- * A shallow feed forward neural network.
+ * A multi layer feed forward neural network.
 * It learns its weights through the backpropagation algorithm via stochastic gradient descent applied to a collection of
 * training samples.
- * Each example is a real vector whose first N elements (identified by the no. of network input units) are the actual
- * outputs and the remaining elements are the input features.
 */
-public class ShallowFeedForwardNeuralNetwork {
+public class MultiLayerNetwork {
private final Configuration configuration;
@@ -51,12 +49,12 @@ public class ShallowFeedForwardNeuralNet
*/
private RealMatrix[] weights;
- public ShallowFeedForwardNeuralNetwork(Configuration configuration) {
+ public MultiLayerNetwork(Configuration configuration) {
this.configuration = configuration;
this.weights = createRandomWeights();
}
- public ShallowFeedForwardNeuralNetwork(Configuration configuration, RealMatrix[] weights) {
+ public MultiLayerNetwork(Configuration configuration, RealMatrix[] weights) {
this.configuration = configuration;
this.weights = weights;
}
@@ -143,7 +141,7 @@ public class ShallowFeedForwardNeuralNet
if (Double.POSITIVE_INFINITY == newCost) {
throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
- } else if (iterations > 1 && (cost == newCost || newCost < configuration.threshold || iterations > configuration.maxIterations)) {
+ } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
break;
} else if (Double.isNaN(newCost)) {
@@ -204,13 +202,13 @@ public class ShallowFeedForwardNeuralNet
}
// compute derivatives
- for (int i = 0; i < deltas.length; i++) {
- ds[i] = deltas[i].scalarMultiply(1d / size);
- }
+// for (int i = 0; i < deltas.length; i++) {
+// ds[i] = deltas[i].scalarMultiply(1d / size);
+// }
// regularization
int l = 0;
- for (RealMatrix d : ds) {
+ for (RealMatrix d : deltas) {
final int finalL = l;
d.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
@Override
@@ -221,7 +219,7 @@ public class ShallowFeedForwardNeuralNet
@Override
public double visit(int row, int column, double value) {
if (column != 0) {
- return value + configuration.alpha * weights[finalL].getEntry(row, column);
+ return value + configuration.alpha * weights[finalL].getEntry(row, column); // assuming regularization factor == learning rate
} else {
return value;
}
@@ -235,7 +233,7 @@ public class ShallowFeedForwardNeuralNet
l++;
}
- return ds;
+ return deltas;
}
private RealVector calculateDeltaVector(RealMatrix weight, RealVector activationsVector, RealVector nextLayerDelta) {
@@ -294,7 +292,17 @@ public class ShallowFeedForwardNeuralNet
res += yo * Math.log(ho) + (1d - yo)
* Math.log(1d - ho);
}
+
return (-1d / size) * res;
+
+// Double res = 0d;
+//
+// for (int i = 0; i < predictedOutput.length; i++) {
+// Double so = expectedOutput[i];
+// Double po = predictedOutput[i];
+// res -= so * Math.log(po);
+// }
+// return res;
}
// --- feed forward ---
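
For reference, the binary cross-entropy cost kept in the hunk above, extracted into a standalone sketch; the class and method names are hypothetical, not part of the commit (the commented-out block suggests a plain negative log-likelihood variant was also being considered):

    public class CrossEntropyCostSketch {
      // binary cross-entropy averaged over the training set size, mirroring
      // the loop retained in MultiLayerNetwork above
      static double cost(double[] expectedOutput, double[] predictedOutput, int size) {
        double res = 0d;
        for (int i = 0; i < predictedOutput.length; i++) {
          double yo = expectedOutput[i]; // expected output
          double ho = predictedOutput[i]; // predicted output (hypothesis)
          res += yo * Math.log(ho) + (1d - yo) * Math.log(1d - ho);
        }
        return (-1d / size) * res;
      }
    }
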
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java?rev=1731036&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java Thu Feb 18 10:13:42 2016
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+
+/**
+ * Rectifier (aka ReLU) activation function
+ */
+public class RectifierFunction implements ActivationFunction {
+ @Override
+ public RealMatrix applyMatrix(RealMatrix weights) {
+ RealMatrix matrix = weights.copy();
+ matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return Math.max(0, value);
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+ return matrix;
+ }
+}
Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java
------------------------------------------------------------------------------
svn:eol-style = native
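
A minimal usage sketch for the rectifier above, assuming commons-math3 on the classpath; the matrix values and example class name are made up for illustration:

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.yay.RectifierFunction;

    public class RectifierFunctionExample {
      public static void main(String[] args) {
        RealMatrix m = MatrixUtils.createRealMatrix(new double[][]{{-1d, 0.5d}, {2d, -3d}});
        // max(0, value) entry-wise: negative entries clamp to zero, the rest pass through
        RealMatrix rectified = new RectifierFunction().applyMatrix(m);
        System.out.println(rectified); // prints {{0, 0.5}, {2, 0}}
      }
    }
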
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1731036&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Feb 18 10:13:42 2016
@@ -0,0 +1,549 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import com.google.common.base.Splitter;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
+
+/**
+ * A skip-gram neural network.
+ * It learns its weights through the backpropagation algorithm via batch gradient descent applied to a collection of
+ * hot encoded training samples.
+ */
+public class SkipGramNetwork {
+
+ private final Configuration configuration;
+ private final RectifierFunction rectifierFunction = new RectifierFunction();
+ private final SoftmaxActivationFunction softmaxActivationFunction = new SoftmaxActivationFunction();
+
+ /**
+ * Each RealMatrix maps weights between two layers.
+ * E.g.: weights[0] controls function mapping from layer 0 to layer 1.
+ * If network has 4 units in layer 1 and 5 units in layer 2, then weights[0] will be of dimension 5x4, plus bias terms.
+ * A network having layers with 3, 4 and 2 units each will have the following weights matrix dimensions:
+ * - weights[0] : 4x3
+ * - weights[1] : 2x4
+ * <p>
+ * the first row of the weights[0] matrix holds the weights connecting each unit of the first layer to the first neuron of the second layer,
+ * the second row of weights[0] holds the weights connecting each unit of the first layer to the second neuron of the second layer, etc.
+ */
+ private RealMatrix[] weights;
+
+
+ private SkipGramNetwork(Configuration configuration) {
+ this.configuration = configuration;
+ this.weights = createRandomWeights();
+ }
+
+ public RealMatrix[] getWeights() {
+ return weights;
+ }
+
+ public List<String> getVocabulary() {
+ return configuration.vocabulary;
+ }
+
+ private RealMatrix[] createRandomWeights() {
+ Random r = new Random();
+ int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs};
+ int[] layers = new int[conf.length];
+ for (int i = 0; i < layers.length; i++) {
+ layers[i] = conf[i] + (i < layers.length - 1 ? 1 : 0);
+ }
+ int weightsCount = layers.length - 1;
+
+ RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+
+ for (int i = 0; i < weightsCount; i++) {
+
+ RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]);
+ final int finalI = i;
+ matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ if (finalI != weightsCount - 1 && row == 0) {
+ return 0d;
+ } else if (column == 0) {
+ return 1d;
+ }
+ return r.nextInt(100) / 101d;
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+
+ initialWeights[i] = matrix;
+ }
+ return initialWeights;
+
+ }
+
+ // --- batch gradient descent ---
+
+ /**
+ * perform weight learning from the training examples using the batch gradient descent algorithm
+ *
+ * @param samples the training examples
+ * @return the final cost with the updated weights
+ * @throws Exception if BGD fails to converge or any numerical error occurs
+ */
+ public double learnWeights(Sample... samples) throws Exception {
+
+ int iterations = 0;
+
+ double cost = Double.MAX_VALUE;
+ long start = System.currentTimeMillis();
+ while (true) {
+ if (iterations % (1 + (configuration.maxIterations / 100)) == 0) {
+ long time = (System.currentTimeMillis() - start) / 1000;
+ if (time > 60) {
+ System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
+ }
+ }
+ double newCost = 0;
+ RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length);
+ RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length);
+ int i = 0;
+ for (Sample sample : samples) {
+ x.setRow(i, ArrayUtils.addAll(sample.getInputs()));
+ y.setRow(i, ArrayUtils.addAll(sample.getOutputs()));
+ i++;
+ }
+
+ RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(weights[0].transpose()));
+ RealMatrix scores = hidden.multiply(weights[1].transpose());
+
+ RealMatrix probs = softmaxActivationFunction.applyMatrix(scores);
+
+ RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1);
+ correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return -Math.log(probs.getEntry(row, getMaxIndex(y.getRow(row))));
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+ double dataLoss = correctLogProbs.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
+ private double d = 0;
+
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public void visit(int row, int column, double value) {
+ d += value;
+ }
+
+ @Override
+ public double end() {
+ return d;
+ }
+ }) / samples.length;
+
+ double reg = 0d;
+ reg += weights[0].walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
+ private double d = 0d;
+
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public void visit(int row, int column, double value) {
+ d += Math.pow(value, 2);
+ }
+
+ @Override
+ public double end() {
+ return d;
+ }
+ });
+ newCost = dataLoss + 0.5 * 0.03 * reg;
+
+ if (Double.POSITIVE_INFINITY == newCost || newCost > cost) {
+ throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
+ } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
+ cost = newCost;
+ System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
+ break;
+ } else if (Double.isNaN(newCost)) {
+ throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost calculation underflow");
+ }
+
+ // update registered cost
+ cost = newCost;
+
+ // calculate the derivatives to update the parameters
+
+ RealMatrix dscores = probs;
+ dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ return y.getEntry(row, column) == 1 ? (value - 1) / samples.length : value / samples.length;
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+
+
+ RealMatrix dW2 = hidden.transpose().multiply(dscores);
+
+ dW2.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ if (column != 0) {
+ return value + 0.03 * weights[1].transpose().getEntry(row, column);
+ } else {
+ return value;
+ }
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+
+ RealMatrix dhidden = dscores.multiply(weights[1]);
+
+ RealMatrix dW = x.transpose().multiply(dhidden);
+ dW.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ if (column != 0) {
+ return value + 0.03 * weights[0].transpose().getEntry(row, column);
+ } else {
+ return value;
+ }
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ });
+
+ RealMatrix[] derivatives = new RealMatrix[]{dW.transpose(), dW2.transpose()};
+
+ // update the weights
+ RealMatrix[] updatedParameters = new RealMatrix[weights.length];
+
+ for (int l = 0; l < weights.length; l++) {
+ RealMatrix realMatrix = weights[l].copy();
+ final int finalL = l;
+ RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {
+
+ @Override
+ public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+ }
+
+ @Override
+ public double visit(int row, int column, double value) {
+ if (!(row == 0 && value == 0d) && !(column == 0 && value == 1d)) {
+ return value - configuration.alpha * derivatives[finalL].getEntry(row, column);
+ } else {
+ return value;
+ }
+ }
+
+ @Override
+ public double end() {
+ return 0;
+ }
+ };
+ realMatrix.walkInOptimizedOrder(visitor);
+ updatedParameters[l] = realMatrix;
+ }
+ weights = updatedParameters;
+
+ iterations++;
+ }
+ return cost;
+ }
+
+ private int getMaxIndex(double[] array) {
+ double largest = array[0];
+ int index = 0;
+ for (int i = 1; i < array.length; i++) {
+ if (array[i] >= largest) {
+ largest = array[i];
+ index = i;
+ }
+ }
+ return index;
+ }
+
+ public static SkipGramNetwork.Builder newModel() {
+ return new Builder();
+ }
+
+ // --- skip gram neural network configuration ---
+
+ private static class Configuration {
+ // internal parameters
+ protected int outputs;
+ protected int inputs;
+
+ protected List<String> vocabulary;
+
+ // user controlled parameters
+ protected Path path;
+ protected int maxIterations;
+ protected double alpha = 0.001d;
+ protected double threshold = 0.004d;
+ protected int vectorSize;
+ protected int window;
+ }
+
+ public static class Builder {
+ private final Configuration configuration;
+
+ public Builder() {
+ this.configuration = new Configuration();
+ }
+
+
+ public Builder withWindow(int w) {
+ this.configuration.window = w;
+ return this;
+ }
+
+ public Builder fromTextAt(Path path) {
+ this.configuration.path = path;
+ return this;
+ }
+
+ public Builder withDimension(int d) {
+ this.configuration.vectorSize = d;
+ return this;
+ }
+
+ public SkipGramNetwork build() throws Exception {
+ System.out.println("reading fragments");
+ Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window);
+ assert !fragments.isEmpty() : "could not read fragments";
+ System.out.println("generating vocabulary");
+ List<String> vocabulary = getVocabulary(this.configuration.path);
+ assert !vocabulary.isEmpty() : "could not read vocabulary";
+ this.configuration.vocabulary = vocabulary;
+
+ System.out.println("creating training set");
+ Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, this.configuration.window);
+ fragments.clear();
+ this.configuration.maxIterations = trainingSet.size() * 10;
+
+ HotEncodedSample next = trainingSet.iterator().next();
+
+ this.configuration.inputs = next.getInputs().length - 1;
+ this.configuration.outputs = next.getOutputs().length;
+
+ SkipGramNetwork network = new SkipGramNetwork(configuration);
+ network.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()]));
+ return network;
+ }
+
+ private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException {
+ long start = System.currentTimeMillis();
+ Collection<HotEncodedSample> samples = new LinkedList<>();
+ List<byte[]> fragment;
+ while ((fragment = fragments.poll()) != null) {
+ byte[] inputWord = null;
+ List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
+ for (int i = 0; i < fragment.size(); i++) {
+ for (int j = 0; j < fragment.size(); j++) {
+ byte[] token = fragment.get(i);
+ if (i == j) {
+ inputWord = token;
+ } else {
+ outputWords.add(token);
+ }
+ }
+ }
+ final byte[] finalInputWord = inputWord;
+
+ double[] doubles = new double[window - 1];
+ for (int i = 0; i < doubles.length; i++) {
+ doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i)));
+ }
+
+ double[] inputs = new double[1];
+ inputs[0] = (double) vocabulary.indexOf(new String(finalInputWord));
+
+ samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size()));
+
+ }
+
+ long end = System.currentTimeMillis();
+ System.out.println("training set created in " + (end - start) / 60000 + " minutes");
+
+ return samples;
+ }
+
+ private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException {
+ long start = System.currentTimeMillis();
+ Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>();
+
+ ByteBuffer buf = ByteBuffer.allocate(100);
+ try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
+
+ String encoding = System.getProperty("file.encoding");
+ StringBuilder previous = new StringBuilder();
+ Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+ while (sbc.read(buf) > 0) {
+ buf.rewind();
+ CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+ String string = cleanString(charBuffer);
+ List<String> split = splitter.splitToList(string);
+ int splitSize = split.size();
+ if (splitSize > w) {
+ for (int j = 0; j < splitSize - w; j++) {
+ List<byte[]> fragment = new ArrayList<>(w);
+ fragment.add(previous.append(split.get(j)).toString().getBytes());
+ for (int i = 1; i < w; i++) {
+ fragment.add(split.get(i + j).getBytes());
+ }
+ // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
+ fragments.add(fragment);
+ previous = new StringBuilder();
+ }
+ previous = new StringBuilder().append(split.get(splitSize - 1));
+ } else if (split.size() == w) {
+ previous.append(string);
+ }
+ buf.flip();
+ }
+ } catch (IOException x) {
+ System.err.println("caught exception: " + x);
+ } finally {
+ buf.clear();
+ }
+ long end = System.currentTimeMillis();
+ System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
+ return fragments;
+ }
+
+ private List<String> getVocabulary(Path path) throws IOException {
+ Set<String> vocabulary = new HashSet<>();
+ ByteBuffer buf = ByteBuffer.allocate(100);
+ try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
+
+ String encoding = System.getProperty("file.encoding");
+ StringBuilder previous = new StringBuilder();
+ Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+ while (sbc.read(buf) > 0) {
+ buf.rewind();
+ CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+ String string = cleanString(charBuffer);
+ List<String> split = splitter.splitToList(string);
+ int splitSize = split.size();
+ if (splitSize > 1) {
+ String term = previous.append(split.get(0)).toString();
+ vocabulary.add(term.intern());
+ for (int i = 1; i < splitSize - 1; i++) {
+ String term2 = split.get(i);
+ vocabulary.add(term2.intern());
+ }
+ previous = new StringBuilder().append(split.get(splitSize - 1));
+ } else if (split.size() == 1) {
+ previous.append(string);
+ }
+ buf.flip();
+ }
+ } catch (IOException x) {
+ System.err.println("caught exception: " + x);
+ } finally {
+ buf.clear();
+ }
+ List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()]));
+ Collections.sort(list);
+// for (String iw : vocabulary) {
+// System.out.println(iw +"->"+Arrays.toString(ConversionUtils.hotEncode(iw.getBytes(), list)));
+// }
+ return list;
+ }
+
+ private String cleanString(CharBuffer charBuffer) {
+ String s = charBuffer.toString();
+ return s.toLowerCase().replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", "");
+ }
+ }
+}
\ No newline at end of file
Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
------------------------------------------------------------------------------
svn:eol-style = native
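
Condensed, the forward pass and cost evaluated on each iteration of learnWeights above amount to the sketch below; the example class and method are hypothetical, while RectifierFunction and SoftmaxActivationFunction are the classes the commit references (note the regularization factor 0.03 is hard-coded and only the first weight matrix is regularized):

    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.yay.RectifierFunction;
    import org.apache.yay.SoftmaxActivationFunction;

    public class SkipGramCostSketch {
      // x: samples x inputs, y: samples x outputs (hot encoded),
      // w0: input-to-hidden weights, w1: hidden-to-output weights
      static double cost(RealMatrix x, RealMatrix y, RealMatrix w0, RealMatrix w1) {
        RealMatrix hidden = new RectifierFunction().applyMatrix(x.multiply(w0.transpose()));
        RealMatrix probs = new SoftmaxActivationFunction().applyMatrix(hidden.multiply(w1.transpose()));
        double dataLoss = 0d;
        for (int row = 0; row < x.getRowDimension(); row++) {
          // argmax over the hot encoded expected outputs, as in getMaxIndex above
          double[] yRow = y.getRow(row);
          int correct = 0;
          for (int c = 1; c < yRow.length; c++) {
            if (yRow[c] >= yRow[correct]) {
              correct = c;
            }
          }
          dataLoss -= Math.log(probs.getEntry(row, correct));
        }
        dataLoss /= x.getRowDimension();
        double reg = 0d; // L2 penalty over the first weight matrix only
        for (double[] rowValues : w0.getData()) {
          for (double v : rowValues) {
            reg += v * v;
          }
        }
        return dataLoss + 0.5 * 0.03 * reg;
      }
    }
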
Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (from r1724846, labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java&r1=1724846&r2=1731036&rev=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Thu Feb 18 10:13:42 2016
@@ -22,25 +22,27 @@ import org.apache.commons.math3.linear.A
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.Test;
+import java.util.Random;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
- * Tests for {@link ShallowFeedForwardNeuralNetwork}
+ * Tests for {@link MultiLayerNetwork}
*/
-public class ShallowFeedForwardNeuralNetworkTest {
+public class MultiLayerNetworkTest {
@Test
public void testLearnAndPredict() throws Exception {
- ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
- configuration.alpha = 0.0001d;
+ MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
+ configuration.alpha = 0.00001d;
configuration.layers = new int[]{3, 4, 1};
configuration.maxIterations = 10000;
- configuration.threshold = 0.004d;
+ configuration.threshold = 0.00000004d;
configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
- ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration);
+ MultiLayerNetwork neuralNetwork = new MultiLayerNetwork(configuration);
assertNotNull(neuralNetwork);
Sample[] samples = new Sample[3];
@@ -55,6 +57,21 @@ public class ShallowFeedForwardNeuralNet
assertNotNull(doubles);
assertEquals(0.9d, doubles[0], 0.2d);
+
+ samples = createRandomSamples(10000);
+ cost = neuralNetwork.learnWeights(samples);
+ assertTrue(cost > 0 && cost < 10);
+ }
+
+ private Sample[] createRandomSamples(int size) {
+ Random r = new Random();
+ Sample[] samples = new Sample[size];
+ for (int i = 0; i < size; i++) {
+ boolean l = r.nextBoolean();
+ samples[i] = new Sample(new double[]{r.nextDouble(), r.nextDouble(), r.nextDouble()}, l ? new double[]{1d} :
+ new double[]{0d});
+ }
+ return samples;
}
@Test
@@ -63,14 +80,14 @@ public class ShallowFeedForwardNeuralNet
RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights);
RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights};
- ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+ MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
configuration.alpha = 0.0001d;
configuration.layers = new int[]{2, 1};
configuration.maxIterations = 10000;
configuration.threshold = 0.004d;
configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
- ShallowFeedForwardNeuralNetwork and = new ShallowFeedForwardNeuralNetwork(configuration, andRealMatrixSet);
+ MultiLayerNetwork and = new MultiLayerNetwork(configuration, andRealMatrixSet);
assertEquals(0L, Math.round(and.predictOutput(new double[]{1d, 0d})[0]));
assertEquals(0L, Math.round(and.predictOutput(new double[]{0d, 1d})[0]));
@@ -84,14 +101,14 @@ public class ShallowFeedForwardNeuralNet
RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights};
- ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+ MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
configuration.alpha = 0.0001d;
configuration.layers = new int[]{2, 1};
configuration.maxIterations = 10000;
configuration.threshold = 0.004d;
configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
- ShallowFeedForwardNeuralNetwork or = new ShallowFeedForwardNeuralNetwork(configuration, orRealMatrixSet);
+ MultiLayerNetwork or = new MultiLayerNetwork(configuration, orRealMatrixSet);
assertEquals(1L, Math.round(or.predictOutput(new double[]{1d, 0d})[0]));
assertEquals(1L, Math.round(or.predictOutput(new double[]{0d, 1d})[0]));
@@ -105,14 +122,14 @@ public class ShallowFeedForwardNeuralNet
RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights);
RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights};
- ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+ MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
configuration.alpha = 0.0001d;
configuration.layers = new int[]{1, 1};
configuration.maxIterations = 10000;
configuration.threshold = 0.004d;
configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
- ShallowFeedForwardNeuralNetwork not = new ShallowFeedForwardNeuralNetwork(configuration, notRealMatrixSet);
+ MultiLayerNetwork not = new MultiLayerNetwork(configuration, notRealMatrixSet);
assertEquals(1L, Math.round(not.predictOutput(new double[]{0d})[0]));
assertEquals(0L, Math.round(not.predictOutput(new double[]{1d})[0]));
}
@@ -123,14 +140,14 @@ public class ShallowFeedForwardNeuralNet
RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}});
RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights};
- ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+ MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
configuration.alpha = 0.0001d;
configuration.layers = new int[]{2, 2, 1};
configuration.maxIterations = 10000;
configuration.threshold = 0.004d;
configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
- ShallowFeedForwardNeuralNetwork nor = new ShallowFeedForwardNeuralNetwork(configuration, norRealMatrixSet);
+ MultiLayerNetwork nor = new MultiLayerNetwork(configuration, norRealMatrixSet);
assertEquals(0L, Math.round(nor.predictOutput(new double[]{1d, 0d})[0]));
Added: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1731036&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Feb 18 10:13:42 2016
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.junit.Test;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Tests for skip gram network
+ */
+public class SkipGramNetworkTest {
+
+ @Test
+ public void testWordVectorsLearningOnAbstracts() throws Exception {
+ Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile());
+ SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build();
+ RealMatrix wv = network.getWeights()[0];
+ List<String> vocabulary = network.getVocabulary();
+ serialize(vocabulary, wv);
+ measure(vocabulary, wv);
+ }
+
+ @Test
+ public void testWordVectorsLearningOnSentences() throws Exception {
+ Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile());
+ SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build();
+ RealMatrix wv = network.getWeights()[0];
+ List<String> vocabulary = network.getVocabulary();
+ serialize(vocabulary, wv);
+ measure(vocabulary, wv);
+ }
+
+ @Test
+ public void testWordVectorsLearningOnTestData() throws Exception {
+ Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
+ SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build();
+ RealMatrix wv = network.getWeights()[0];
+ List<String> vocabulary = network.getVocabulary();
+ serialize(vocabulary, wv);
+ measure(vocabulary, wv);
+ }
+
+ private void measure(List<String> vocabulary, RealMatrix wordVectors) {
+ System.out.println("measuring similarities");
+ Collection<DistanceMeasure> measures = new LinkedList<>();
+ measures.add(new EuclideanDistance());
+// measures.add(new DistanceMeasure() {
+// @Override
+// public double compute(double[] a, double[] b) {
+// double dp = 0.0;
+// double na = 0.0;
+// double nb = 0.0;
+// for (int i = 0; i < a.length; i++) {
+// dp += a[i] * b[i];
+// na += Math.pow(a[i], 2);
+// nb += Math.pow(b[i], 2);
+// }
+// double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
+// return 1 / cosineSimilarity;
+// }
+//
+// @Override
+// public String toString() {
+// return "inverse cosine similarity distance measure";
+// }
+// });
+// measures.add((DistanceMeasure) (a, b) -> {
+// double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a)));
+// double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b)));
+// return Math.abs(db - da);
+// });
+ for (DistanceMeasure distanceMeasure : measures) {
+ System.out.println("computing similarity using " + distanceMeasure);
+ computeSimilarities(vocabulary, wordVectors, distanceMeasure);
+ }
+
+ }
+
+ private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException {
+ System.out.println("serializing word vectors");
+ BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
+ for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+ double[] a = wordVectors.getColumnVector(i).toArray();
+ String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
+ csq = csq.substring(1, csq.length() - 1);
+ bufferedWriter.append(csq);
+ bufferedWriter.append(", ");
+ bufferedWriter.append(vocabulary.get(i - 1));
+ bufferedWriter.newLine();
+ }
+ bufferedWriter.flush();
+ bufferedWriter.close();
+
+ // for post processing with dimensionality reduction (PCA, t-SNE, etc.):
+ // values: awk '{$hiddenSize=""; print $0}' target/sg-vectors.csv
+ // keys: awk '{print $hiddenSize}' target/sg-vectors.csv
+ }
+
+ private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) {
+ for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+ double[] subjectVector = wordVectors.getColumn(i);
+ subjectVector = Arrays.copyOfRange(subjectVector, 1, subjectVector.length);
+ double maxSimilarity = -Double.MAX_VALUE;
+ double maxSimilarity1 = -Double.MAX_VALUE;
+ double maxSimilarity2 = -Double.MAX_VALUE;
+ int j0 = -1;
+ int j1 = -1;
+ int j2 = -1;
+ for (int j = 1; j < wordVectors.getColumnDimension(); j++) {
+ if (i != j) {
+ double[] vector = wordVectors.getColumn(j);
+ vector = Arrays.copyOfRange(vector, 1, vector.length);
+ double similarity = 1d / distanceMeasure.compute(subjectVector, vector);
+ if (similarity > maxSimilarity) {
+ maxSimilarity2 = maxSimilarity1;
+ j2 = j1;
+
+ maxSimilarity1 = maxSimilarity;
+ j1 = j0;
+
+ maxSimilarity = similarity;
+ j0 = j;
+ } else if (similarity > maxSimilarity1) {
+ maxSimilarity2 = maxSimilarity1;
+ j2 = j1;
+
+ maxSimilarity1 = similarity;
+ j1 = j;
+ } else if (similarity > maxSimilarity2) {
+ maxSimilarity2 = similarity;
+ j2 = j;
+ }
+ }
+ }
+ if (i > 0 && j0 > 0 && j1 > 0 && j2 > 0) {
+ System.out.println(vocabulary.get(i - 1) + " -> "
+ + vocabulary.get(j0 - 1) + ", "
+ + vocabulary.get(j1 - 1) + ", "
+ + vocabulary.get(j2 - 1));
+ } else {
+ System.err.println("no similarity for '" + vocabulary.get(i - 1) + "' with " + distanceMeasure);
+ }
+ }
+ }
+}
Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
------------------------------------------------------------------------------
svn:eol-style = native
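
The cosine measure commented out in measure() above can be restored as a working DistanceMeasure; a sketch (the class name is made up, and since computeSimilarities ranks neighbours by 1/distance, using this measure amounts to ranking by raw cosine similarity):

    import org.apache.commons.math3.ml.distance.DistanceMeasure;

    public class InverseCosineSimilarityDistance implements DistanceMeasure {
      @Override
      public double compute(double[] a, double[] b) {
        double dp = 0d; // dot product
        double na = 0d; // squared norm of a
        double nb = 0d; // squared norm of b
        for (int i = 0; i < a.length; i++) {
          dp += a[i] * b[i];
          na += a[i] * a[i];
          nb += b[i] * b[i];
        }
        // distance is the inverse of cosine similarity, as in the commented-out block
        return (Math.sqrt(na) * Math.sqrt(nb)) / dp;
      }
    }
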