You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2016/02/18 11:13:42 UTC

svn commit: r1731036 - in /labs/yay/trunk/core: ./ src/main/java/org/apache/yay/ src/test/java/org/apache/yay/

Author: tommaso
Date: Thu Feb 18 10:13:42 2016
New Revision: 1731036

URL: http://svn.apache.org/viewvc?rev=1731036&view=rev
Log:
refactored SFFNN to MLN, added ReLU function and (compact) skip-gram

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java
      - copied, changed from r1724846, labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java   (with props)
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java   (with props)
    labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
      - copied, changed from r1724846, labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java   (with props)
Removed:
    labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java
Modified:
    labs/yay/trunk/core/pom.xml

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1731036&r1=1731035&r2=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Thu Feb 18 10:13:42 2016
@@ -51,7 +51,6 @@
             <groupId>com.google.guava</groupId>
             <artifactId>guava</artifactId>
             <version>18.0</version>
-            <scope>test</scope>
         </dependency>
     </dependencies>
     <build>

Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java (from r1724846, labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java&r1=1724846&r2=1731036&rev=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java Thu Feb 18 10:13:42 2016
@@ -28,13 +28,11 @@ import java.util.Arrays;
 import java.util.Random;
 
 /**
- * A shallow feed forward neural network.
+ * A multi layer feed forward neural network.
  * It learns its weights through backpropagation algorithm via stochastic gradient descent applied to a collection of
  * training samples.
- * Each example is a real vectors whose first N elements (identified by the no. of network input units) are the actual
- * outputs and the remaining elements are the input features.
  */
-public class ShallowFeedForwardNeuralNetwork {
+public class MultiLayerNetwork {
 
   private final Configuration configuration;
 
@@ -51,12 +49,12 @@ public class ShallowFeedForwardNeuralNet
    */
   private RealMatrix[] weights;
 
-  public ShallowFeedForwardNeuralNetwork(Configuration configuration) {
+  public MultiLayerNetwork(Configuration configuration) {
     this.configuration = configuration;
     this.weights = createRandomWeights();
   }
 
-  public ShallowFeedForwardNeuralNetwork(Configuration configuration, RealMatrix[] weights) {
+  public MultiLayerNetwork(Configuration configuration, RealMatrix[] weights) {
     this.configuration = configuration;
     this.weights = weights;
   }
@@ -143,7 +141,7 @@ public class ShallowFeedForwardNeuralNet
 
       if (Double.POSITIVE_INFINITY == newCost) {
         throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
-      } else if (iterations > 1 && (cost == newCost || newCost < configuration.threshold || iterations > configuration.maxIterations)) {
+      } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
         System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
         break;
       } else if (Double.isNaN(newCost)) {
@@ -204,13 +202,13 @@ public class ShallowFeedForwardNeuralNet
     }
 
     // compute derivatives
-    for (int i = 0; i < deltas.length; i++) {
-      ds[i] = deltas[i].scalarMultiply(1d / size);
-    }
+//    for (int i = 0; i < deltas.length; i++) {
+//      ds[i] = deltas[i].scalarMultiply(1d / size);
+//    }
 
     // regularization
     int l = 0;
-    for (RealMatrix d : ds) {
+    for (RealMatrix d : deltas) {
       final int finalL = l;
       d.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
         @Override
@@ -221,7 +219,7 @@ public class ShallowFeedForwardNeuralNet
         @Override
         public double visit(int row, int column, double value) {
           if (column != 0) {
-            return value + configuration.alpha * weights[finalL].getEntry(row, column);
+            return value + configuration.alpha * weights[finalL].getEntry(row, column); // assuming regularization factor == learning rate
           } else {
             return value;
           }
@@ -235,7 +233,7 @@ public class ShallowFeedForwardNeuralNet
       l++;
     }
 
-    return ds;
+    return deltas;
   }
 
   private RealVector calculateDeltaVector(RealMatrix weight, RealVector activationsVector, RealVector nextLayerDelta) {
@@ -294,7 +292,17 @@ public class ShallowFeedForwardNeuralNet
       res += yo * Math.log(ho) + (1d - yo)
               * Math.log(1d - ho);
     }
+
     return (-1d / size) * res;
+
+//    Double res = 0d;
+//
+//    for (int i = 0; i < predictedOutput.length; i++) {
+//      Double so = expectedOutput[i];
+//      Double po = predictedOutput[i];
+//      res -= so * Math.log(po);
+//    }
+//    return res;
   }
 
   // --- feed forward ---

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java?rev=1731036&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java Thu Feb 18 10:13:42 2016
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+
+/**
+ * Rectifier (aka ReLU) activation function: f(x) = max(0, x), applied element-wise.
+ */
+public class RectifierFunction implements ActivationFunction {
+
+  /**
+   * Rectifies every entry of the given matrix.
+   * The argument is left untouched: a copy is rectified and returned.
+   *
+   * @param weights the matrix to rectify
+   * @return a new matrix whose entries are {@code max(0, x)} of the corresponding input entries
+   */
+  @Override
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    final RealMatrix rectified = weights.copy();
+    final int rows = rectified.getRowDimension();
+    final int columns = rectified.getColumnDimension();
+    for (int r = 0; r < rows; r++) {
+      for (int c = 0; c < columns; c++) {
+        // negative entries are clamped to zero, all the others pass through unchanged
+        rectified.setEntry(r, c, Math.max(0, rectified.getEntry(r, c)));
+      }
+    }
+    return rectified;
+  }
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java
------------------------------------------------------------------------------
    svn:eol-style = native

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1731036&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Feb 18 10:13:42 2016
@@ -0,0 +1,549 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import com.google.common.base.Splitter;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
+
+/**
+ * A skip-gram neural network.
+ * It learns its weights through backpropagation algorithm via batch gradient descent applied to a collection of
+ * hot encoded training samples: a ReLU hidden layer of {@code vectorSize} units feeds a softmax output layer.
+ */
+public class SkipGramNetwork {
+
+  private final Configuration configuration;
+  private final RectifierFunction rectifierFunction = new RectifierFunction();
+  private final SoftmaxActivationFunction softmaxActivationFunction = new SoftmaxActivationFunction();
+
+  /**
+   * Each RealMatrix maps weights between two adjacent layers, as built by {@link #createRandomWeights()}:
+   * - weights[0] : (vectorSize + 1) x (inputs + 1), from the input layer to the hidden layer
+   * - weights[1] : outputs x (vectorSize + 1), from the hidden layer to the output layer
+   * <p>
+   * Row i of weights[l] holds the incoming weights of unit i of layer l + 1; column 0 holds the bias weights.
+   * Column 0 of every matrix and row 0 of weights[0] keep their initial values during training.
+   */
+  private RealMatrix[] weights;
+
+  private SkipGramNetwork(Configuration configuration) {
+    this.configuration = configuration;
+    this.weights = createRandomWeights();
+  }
+
+  /**
+   * @return the current weight matrices (not a copy)
+   */
+  public RealMatrix[] getWeights() {
+    return weights;
+  }
+
+  /**
+   * @return the sorted vocabulary; the position of a word is the index stored in the training samples
+   */
+  public List<String> getVocabulary() {
+    return configuration.vocabulary;
+  }
+
+  /**
+   * Creates the initial weights: row 0 of every non-output matrix is all zeros, the remaining column-0 (bias)
+   * entries are 1 and every other weight is a random value in {0, 1/101, ..., 99/101}.
+   */
+  private RealMatrix[] createRandomWeights() {
+    Random r = new Random();
+    int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs};
+    int[] layers = new int[conf.length];
+    for (int i = 0; i < layers.length; i++) {
+      // every layer but the output one gets an extra (bias) unit
+      layers[i] = conf[i] + (i < layers.length - 1 ? 1 : 0);
+    }
+    int weightsCount = layers.length - 1;
+
+    RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+    for (int i = 0; i < weightsCount; i++) {
+      RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]);
+      final int finalI = i;
+      matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          if (finalI != weightsCount - 1 && row == 0) {
+            return 0d;
+          } else if (column == 0) {
+            return 1d;
+          }
+          return r.nextInt(100) / 101d;
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+      initialWeights[i] = matrix;
+    }
+    return initialWeights;
+  }
+
+  // --- batch gradient descent ---
+
+  /**
+   * Performs weights learning from the training examples using the batch gradient descent algorithm.
+   * <p>
+   * Each iteration runs a forward pass (ReLU hidden layer, softmax output), computes the cross-entropy data loss
+   * plus an L2 penalty on both weight matrices, back-propagates the error and updates the trainable weights.
+   *
+   * @param samples the training examples; all must have the same number of inputs and outputs
+   * @return the final cost with the updated weights
+   * @throws Exception if BGD fails to converge (the cost grows or becomes infinite) or the cost becomes NaN
+   */
+  public double learnWeights(Sample... samples) throws Exception {
+    final double reg = configuration.regularization;
+
+    // the training set does not change between iterations, so its matrices are built only once
+    RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length);
+    RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length);
+    int i = 0;
+    for (Sample sample : samples) {
+      x.setRow(i, ArrayUtils.addAll(sample.getInputs()));
+      y.setRow(i, ArrayUtils.addAll(sample.getOutputs()));
+      i++;
+    }
+
+    int iterations = 0;
+    double cost = Double.MAX_VALUE;
+    long start = System.currentTimeMillis();
+    while (true) {
+      if (iterations % (1 + (configuration.maxIterations / 100)) == 0) {
+        long time = (System.currentTimeMillis() - start) / 1000;
+        if (time > 60) {
+          System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
+        }
+      }
+
+      // --- forward pass ---
+      RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(weights[0].transpose()));
+      RealMatrix scores = hidden.multiply(weights[1].transpose());
+      RealMatrix probs = softmaxActivationFunction.applyMatrix(scores);
+
+      // cross-entropy data loss: mean of -log(probability assigned to each sample's expected output)
+      double logProbSum = 0d;
+      for (int row = 0; row < samples.length; row++) {
+        logProbSum -= Math.log(probs.getEntry(row, getMaxIndex(y.getRow(row))));
+      }
+      double dataLoss = logProbSum / samples.length;
+
+      // L2 penalty over both weight matrices, consistent with the reg * W terms added to both gradients below
+      double newCost = dataLoss + 0.5 * reg * (sumOfSquares(weights[0]) + sumOfSquares(weights[1]));
+
+      if (Double.isNaN(newCost)) {
+        throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost calculation underflow");
+      } else if (Double.POSITIVE_INFINITY == newCost || newCost > cost) {
+        throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
+      } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
+        cost = newCost;
+        System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
+        break;
+      }
+
+      // update registered cost
+      cost = newCost;
+
+      // --- backward pass ---
+
+      // gradient w.r.t. the scores, (probs - y) / n, computed in place since probs is not needed anymore
+      RealMatrix dscores = probs;
+      dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return y.getEntry(row, column) == 1 ? (value - 1) / samples.length : value / samples.length;
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
+      // hidden^T * dscores has the layout of weights[1] transposed
+      RealMatrix dW2 = hidden.transpose().multiply(dscores);
+      addRegularization(dW2, weights[1], reg);
+
+      RealMatrix dhidden = dscores.multiply(weights[1]);
+      // back-propagate through the ReLU: no gradient flows where the hidden activation was not positive
+      dhidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return hidden.getEntry(row, column) <= 0 ? 0d : value;
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
+      // x^T * dhidden has the layout of weights[0] transposed
+      RealMatrix dW = x.transpose().multiply(dhidden);
+      addRegularization(dW, weights[0], reg);
+
+      RealMatrix[] derivatives = new RealMatrix[]{dW.transpose(), dW2.transpose()};
+
+      // update the weights
+      RealMatrix[] updatedParameters = new RealMatrix[weights.length];
+      for (int l = 0; l < weights.length; l++) {
+        RealMatrix realMatrix = weights[l].copy();
+        final int finalL = l;
+        // row 0 of a non-output matrix (the extra unit added by createRandomWeights) is not trained
+        final boolean frozenFirstRow = l < weights.length - 1;
+        realMatrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+          @Override
+          public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+          }
+
+          @Override
+          public double visit(int row, int column, double value) {
+            // bias weights (column 0) are not trained either; checking positions rather than values keeps a
+            // trainable weight that happens to be exactly 0 or 1 from being frozen by accident
+            if (column == 0 || (frozenFirstRow && row == 0)) {
+              return value;
+            }
+            return value - configuration.alpha * derivatives[finalL].getEntry(row, column);
+          }
+
+          @Override
+          public double end() {
+            return 0;
+          }
+        });
+        updatedParameters[l] = realMatrix;
+      }
+      weights = updatedParameters;
+
+      iterations++;
+    }
+    return cost;
+  }
+
+  /**
+   * Adds the L2 regularization term {@code reg * w} to a gradient laid out as the transpose of {@code w}.
+   * Row 0 of the gradient, which pairs with the bias column of {@code w}, is left untouched.
+   *
+   * @param transposedGradient the gradient to update in place, with the dimensions of {@code w.transpose()}
+   * @param w                  the weights being regularized
+   * @param reg                the regularization strength
+   */
+  private static void addRegularization(RealMatrix transposedGradient, RealMatrix w, double reg) {
+    transposedGradient.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        // entry (row, column) of the gradient pairs with entry (column, row) of w, no need to transpose w
+        return row != 0 ? value + reg * w.getEntry(column, row) : value;
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+  }
+
+  /**
+   * @return the sum of the squares of all the entries of {@code matrix}
+   */
+  private static double sumOfSquares(RealMatrix matrix) {
+    double sum = 0d;
+    for (int row = 0; row < matrix.getRowDimension(); row++) {
+      for (int column = 0; column < matrix.getColumnDimension(); column++) {
+        sum += Math.pow(matrix.getEntry(row, column), 2);
+      }
+    }
+    return sum;
+  }
+
+  /**
+   * @return the index of the largest value in the non-empty {@code array}; on ties the last such index wins
+   */
+  private int getMaxIndex(double[] array) {
+    double largest = array[0];
+    int index = 0;
+    for (int i = 1; i < array.length; i++) {
+      if (array[i] >= largest) {
+        largest = array[i];
+        index = i;
+      }
+    }
+    return index;
+  }
+
+  /**
+   * @return a builder that configures, trains and returns a new skip-gram network
+   */
+  public static SkipGramNetwork.Builder newModel() {
+    return new Builder();
+  }
+
+  // --- skip gram neural network configuration ---
+
+  private static class Configuration {
+    // internal parameters
+    protected int outputs;
+    protected int inputs;
+
+    protected List<String> vocabulary;
+
+    // user controlled parameters
+    protected Path path;
+    protected int maxIterations;
+    protected double alpha = 0.001d;
+    protected double threshold = 0.004d;
+    // L2 regularization strength (lambda) applied to the weights
+    protected double regularization = 0.03d;
+    protected int vectorSize;
+    protected int window;
+  }
+
+  public static class Builder {
+    private final Configuration configuration;
+
+    public Builder() {
+      this.configuration = new Configuration();
+    }
+
+    /** @param w the number of consecutive tokens in each training fragment, at least 2 */
+    public Builder withWindow(int w) {
+      this.configuration.window = w;
+      return this;
+    }
+
+    /** @param path the text file the training text is read from */
+    public Builder fromTextAt(Path path) {
+      this.configuration.path = path;
+      return this;
+    }
+
+    /** @param d the size of the learned word vectors, i.e. the number of hidden units */
+    public Builder withDimension(int d) {
+      this.configuration.vectorSize = d;
+      return this;
+    }
+
+    /**
+     * Reads the text, builds the vocabulary and the training set, then trains a new network on them.
+     *
+     * @return the trained network
+     * @throws Exception if the window is too small, the text cannot be read or yields no data, or training fails
+     */
+    public SkipGramNetwork build() throws Exception {
+      if (this.configuration.window < 2) {
+        throw new IllegalArgumentException("window must be at least 2, got " + this.configuration.window);
+      }
+      System.out.println("reading fragments");
+      Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window);
+      if (fragments.isEmpty()) {
+        throw new IllegalStateException("could not read fragments");
+      }
+      System.out.println("generating vocabulary");
+      List<String> vocabulary = getVocabulary(this.configuration.path);
+      if (vocabulary.isEmpty()) {
+        throw new IllegalStateException("could not read vocabulary");
+      }
+      this.configuration.vocabulary = vocabulary;
+
+      System.out.println("creating training set");
+      Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, this.configuration.window);
+      fragments.clear();
+      this.configuration.maxIterations = trainingSet.size() * 10;
+
+      HotEncodedSample next = trainingSet.iterator().next();
+
+      this.configuration.inputs = next.getInputs().length - 1;
+      this.configuration.outputs = next.getOutputs().length;
+
+      SkipGramNetwork network = new SkipGramNetwork(configuration);
+      network.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()]));
+      return network;
+    }
+
+    /**
+     * Turns every fragment into one skip-gram sample: its centre token is the input word and the other
+     * {@code window - 1} tokens are the context words to predict.
+     */
+    private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException {
+      long start = System.currentTimeMillis();
+      Collection<HotEncodedSample> samples = new LinkedList<>();
+      List<byte[]> fragment;
+      while ((fragment = fragments.poll()) != null) {
+        int center = fragment.size() / 2;
+        double[] inputs = new double[1];
+        inputs[0] = indexOf(vocabulary, fragment.get(center));
+
+        // every token but the centre one, each exactly once and in fragment order
+        double[] doubles = new double[window - 1];
+        int o = 0;
+        for (int i = 0; i < fragment.size() && o < doubles.length; i++) {
+          if (i != center) {
+            doubles[o++] = indexOf(vocabulary, fragment.get(i));
+          }
+        }
+        samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size()));
+      }
+
+      long end = System.currentTimeMillis();
+      System.out.println("training set created in " + (end - start) / 60000 + " minutes");
+
+      return samples;
+    }
+
+    /**
+     * Position of a token in the sorted vocabulary, or -1 when absent (the {@link List#indexOf(Object)} contract),
+     * found by binary search instead of a linear scan per token.
+     */
+    private static double indexOf(List<String> sortedVocabulary, byte[] token) {
+      return Math.max(-1, Collections.binarySearch(sortedVocabulary, new String(token)));
+    }
+
+    /**
+     * Reads the text at {@code path} in 100-byte chunks and splits every chunk into fragments of {@code w}
+     * consecutive tokens.
+     */
+    private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException {
+      long start = System.currentTimeMillis();
+      Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>();
+
+      ByteBuffer buf = ByteBuffer.allocate(100);
+      try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
+        Charset charset = Charset.forName(System.getProperty("file.encoding"));
+        StringBuilder previous = new StringBuilder();
+        Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+        while (sbc.read(buf) > 0) {
+          // decode only the bytes read in this round, not stale bytes left from the previous one
+          buf.flip();
+          // NOTE(review): a multi-byte character split across two chunks is decoded as malformed input
+          CharBuffer charBuffer = charset.decode(buf);
+          String string = cleanString(charBuffer);
+          List<String> split = splitter.splitToList(string);
+          int splitSize = split.size();
+          if (splitSize > w) {
+            for (int j = 0; j < splitSize - w; j++) {
+              List<byte[]> fragment = new ArrayList<>(w);
+              fragment.add(previous.append(split.get(j)).toString().getBytes());
+              for (int i = 1; i < w; i++) {
+                fragment.add(split.get(i + j).getBytes());
+              }
+              // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
+              fragments.add(fragment);
+              previous = new StringBuilder();
+            }
+            previous = new StringBuilder().append(split.get(splitSize - 1));
+          } else if (split.size() == w) {
+            previous.append(string);
+          }
+          // make the whole buffer available to the next read
+          buf.clear();
+        }
+      }
+      long end = System.currentTimeMillis();
+      System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
+      return fragments;
+    }
+
+    /**
+     * Reads the text at {@code path} in 100-byte chunks and returns its distinct tokens in natural String order.
+     */
+    private List<String> getVocabulary(Path path) throws IOException {
+      Set<String> vocabulary = new HashSet<>();
+      ByteBuffer buf = ByteBuffer.allocate(100);
+      try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
+        Charset charset = Charset.forName(System.getProperty("file.encoding"));
+        StringBuilder previous = new StringBuilder();
+        Splitter splitter = Splitter.on(Pattern.compile("[\\\n\\s]")).omitEmptyStrings().trimResults();
+        while (sbc.read(buf) > 0) {
+          // decode only the bytes read in this round, not stale bytes left from the previous one
+          buf.flip();
+          CharBuffer charBuffer = charset.decode(buf);
+          String string = cleanString(charBuffer);
+          List<String> split = splitter.splitToList(string);
+          int splitSize = split.size();
+          if (splitSize > 1) {
+            String term = previous.append(split.get(0)).toString();
+            vocabulary.add(term.intern());
+            for (int i = 1; i < splitSize - 1; i++) {
+              String term2 = split.get(i);
+              vocabulary.add(term2.intern());
+            }
+            previous = new StringBuilder().append(split.get(splitSize - 1));
+          } else if (split.size() == 1) {
+            previous.append(string);
+          }
+          // make the whole buffer available to the next read
+          buf.clear();
+        }
+      }
+      List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()]));
+      Collections.sort(list);
+      return list;
+    }
+
+    /** Lower-cases the text, turns . ; , : into spaces and removes double quotes and "-" followed by whitespace. */
+    private String cleanString(CharBuffer charBuffer) {
+      String s = charBuffer.toString();
+      return s.toLowerCase().replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", "");
+    }
+  }
+}
\ No newline at end of file

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
------------------------------------------------------------------------------
    svn:eol-style = native

Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (from r1724846, labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java&r1=1724846&r2=1731036&rev=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Thu Feb 18 10:13:42 2016
@@ -22,25 +22,27 @@ import org.apache.commons.math3.linear.A
 import org.apache.commons.math3.linear.RealMatrix;
 import org.junit.Test;
 
+import java.util.Random;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 
 /**
- * Tests for {@link ShallowFeedForwardNeuralNetwork}
+ * Tests for {@link MultiLayerNetwork}
  */
-public class ShallowFeedForwardNeuralNetworkTest {
+public class MultiLayerNetworkTest {
 
   @Test
   public void testLearnAndPredict() throws Exception {
-    ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
-    configuration.alpha = 0.0001d;
+    MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
+    configuration.alpha = 0.00001d;
     configuration.layers = new int[]{3, 4, 1};
     configuration.maxIterations = 10000;
-    configuration.threshold = 0.004d;
+    configuration.threshold = 0.00000004d;
     configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
 
-    ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration);
+    MultiLayerNetwork neuralNetwork = new MultiLayerNetwork(configuration);
 
     assertNotNull(neuralNetwork);
     Sample[] samples = new Sample[3];
@@ -55,6 +57,21 @@ public class ShallowFeedForwardNeuralNet
     assertNotNull(doubles);
 
     assertEquals(0.9d, doubles[0], 0.2d);
+
+    samples = createRandomSamples(10000);
+    cost = neuralNetwork.learnWeights(samples);
+    assertTrue(cost > 0 && cost < 10);
+  }
+
+  private Sample[] createRandomSamples(int size) {
+    Random random = new Random();
+    Sample[] randomSamples = new Sample[size];
+    for (int k = 0; k < size; k++) { // each sample: a random 0/1 label and three uniform features in [0, 1)
+      boolean positive = random.nextBoolean();
+      double[] features = {random.nextDouble(), random.nextDouble(), random.nextDouble()};
+      randomSamples[k] = new Sample(features, positive ? new double[]{1d} : new double[]{0d});
+    }
+    return randomSamples;
   }
 
   @Test
@@ -63,14 +80,14 @@ public class ShallowFeedForwardNeuralNet
     RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights};
 
-    ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+    MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
     configuration.alpha = 0.0001d;
     configuration.layers = new int[]{2, 1};
     configuration.maxIterations = 10000;
     configuration.threshold = 0.004d;
     configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
 
-    ShallowFeedForwardNeuralNetwork and = new ShallowFeedForwardNeuralNetwork(configuration, andRealMatrixSet);
+    MultiLayerNetwork and = new MultiLayerNetwork(configuration, andRealMatrixSet);
 
     assertEquals(0L, Math.round(and.predictOutput(new double[]{1d, 0d})[0]));
     assertEquals(0L, Math.round(and.predictOutput(new double[]{0d, 1d})[0]));
@@ -84,14 +101,14 @@ public class ShallowFeedForwardNeuralNet
     RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights};
 
-    ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+    MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
     configuration.alpha = 0.0001d;
     configuration.layers = new int[]{2, 1};
     configuration.maxIterations = 10000;
     configuration.threshold = 0.004d;
     configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
 
-    ShallowFeedForwardNeuralNetwork or = new ShallowFeedForwardNeuralNetwork(configuration, orRealMatrixSet);
+    MultiLayerNetwork or = new MultiLayerNetwork(configuration, orRealMatrixSet);
 
     assertEquals(1L, Math.round(or.predictOutput(new double[]{1d, 0d})[0]));
     assertEquals(1L, Math.round(or.predictOutput(new double[]{0d, 1d})[0]));
@@ -105,14 +122,14 @@ public class ShallowFeedForwardNeuralNet
     RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights};
 
-    ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+    MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
     configuration.alpha = 0.0001d;
     configuration.layers = new int[]{1, 1};
     configuration.maxIterations = 10000;
     configuration.threshold = 0.004d;
     configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
 
-    ShallowFeedForwardNeuralNetwork not = new ShallowFeedForwardNeuralNetwork(configuration, notRealMatrixSet);
+    MultiLayerNetwork not = new MultiLayerNetwork(configuration, notRealMatrixSet);
     assertEquals(1L, Math.round(not.predictOutput(new double[]{0d})[0]));
     assertEquals(0L, Math.round(not.predictOutput(new double[]{1d})[0]));
   }
@@ -123,14 +140,14 @@ public class ShallowFeedForwardNeuralNet
     RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}});
     RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights};
 
-    ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+    MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
     configuration.alpha = 0.0001d;
     configuration.layers = new int[]{2, 2, 1};
     configuration.maxIterations = 10000;
     configuration.threshold = 0.004d;
     configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};
 
-    ShallowFeedForwardNeuralNetwork nor = new ShallowFeedForwardNeuralNetwork(configuration, norRealMatrixSet);
+    MultiLayerNetwork nor = new MultiLayerNetwork(configuration, norRealMatrixSet);
 
 
     assertEquals(0L, Math.round(nor.predictOutput(new double[]{1d, 0d})[0]));

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1731036&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Feb 18 10:13:42 2016
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.junit.Test;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Tests for {@link SkipGramNetwork}: learns word vectors from text resources under
+ * {@code /word2vec/}, writes them to {@code target/sg-vectors.csv} and prints, for each word,
+ * its three nearest neighbours.
+ * <p>
+ * These tests make no assertions on vector quality; they only check that training,
+ * serialization and similarity computation complete without throwing.
+ */
+public class SkipGramNetworkTest {
+
+  /** Context window size used for every test model. */
+  private static final int WINDOW = 4;
+
+  /** Word vector dimension used for every test model. */
+  private static final int DIMENSION = 10;
+
+  /** Output file for the serialized word vectors, relative to the working directory. */
+  private static final String VECTORS_FILE = "target/sg-vectors.csv";
+
+  @Test
+  public void testWordVectorsLearningOnAbstracts() throws Exception {
+    learnAndReport("/word2vec/abstracts.txt");
+  }
+
+  @Test
+  public void testWordVectorsLearningOnSentences() throws Exception {
+    learnAndReport("/word2vec/sentences.txt");
+  }
+
+  @Test
+  public void testWordVectorsLearningOnTestData() throws Exception {
+    learnAndReport("/word2vec/test.txt");
+  }
+
+  /**
+   * Trains a skip-gram model on a classpath text resource, serializes its first weight matrix
+   * and prints nearest-neighbour words.
+   *
+   * @param resource absolute classpath location of the training text
+   * @throws FileNotFoundException if the resource is not on the classpath
+   * @throws Exception             if resolving the resource, training or serialization fails
+   */
+  private void learnAndReport(String resource) throws Exception {
+    URL url = getClass().getResource(resource);
+    if (url == null) {
+      throw new FileNotFoundException("test resource not found on classpath: " + resource);
+    }
+    // toURI() decodes escapes such as %20; URL.getFile() keeps them and yields a bad path
+    Path path = Paths.get(url.toURI());
+    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(WINDOW).fromTextAt(path)
+        .withDimension(DIMENSION).build();
+    RealMatrix wv = network.getWeights()[0];
+    List<String> vocabulary = network.getVocabulary();
+    serialize(vocabulary, wv);
+    measure(vocabulary, wv);
+  }
+
+  /**
+   * Prints, for every word, its three nearest neighbours under each configured distance measure.
+   *
+   * @param vocabulary  words; {@code vocabulary.get(i - 1)} labels column {@code i}
+   * @param wordVectors matrix whose columns, starting at index 1, hold the word vectors
+   */
+  private void measure(List<String> vocabulary, RealMatrix wordVectors) {
+    System.out.println("measuring similarities");
+    Collection<DistanceMeasure> measures = new LinkedList<>();
+    measures.add(new EuclideanDistance());
+    for (DistanceMeasure distanceMeasure : measures) {
+      System.out.println("computing similarity using " + distanceMeasure);
+      computeSimilarities(vocabulary, wordVectors, distanceMeasure);
+    }
+  }
+
+  /**
+   * Writes one line per word to {@value #VECTORS_FILE}: the vector components followed by the
+   * word itself, all separated by {@code ", "}.
+   *
+   * @param vocabulary  words; {@code vocabulary.get(i - 1)} labels column {@code i}
+   * @param wordVectors matrix whose columns, starting at index 1, hold the word vectors
+   * @throws IOException if the output file cannot be created or written
+   */
+  private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException {
+    System.out.println("serializing word vectors");
+    Path output = Paths.get(VECTORS_FILE);
+    Path parent = output.getParent();
+    if (parent != null) {
+      Files.createDirectories(parent);
+    }
+    // try-with-resources closes the writer even when a write fails
+    try (BufferedWriter bufferedWriter = Files.newBufferedWriter(output, StandardCharsets.UTF_8)) {
+      for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+        String csq = Arrays.toString(wordVector(wordVectors, i));
+        csq = csq.substring(1, csq.length() - 1); // strip the enclosing "[" and "]"
+        bufferedWriter.append(csq);
+        bufferedWriter.append(", ");
+        bufferedWriter.append(vocabulary.get(i - 1));
+        bufferedWriter.newLine();
+      }
+    }
+
+    // for post processing with dimensionality reduction (PCA, t-SNE, etc.), assuming no word
+    // contains a comma:
+    // values: sed 's/, [^,]*$//' target/sg-vectors.csv
+    // keys:   awk -F', ' '{print $NF}' target/sg-vectors.csv
+  }
+
+  /**
+   * Returns column {@code column} of {@code wordVectors} without its first row.
+   * NOTE(review): row 0, like column 0, is presumably a bias unit; confirm against
+   * SkipGramNetwork's weight layout.
+   */
+  private static double[] wordVector(RealMatrix wordVectors, int column) {
+    double[] values = wordVectors.getColumn(column);
+    return Arrays.copyOfRange(values, 1, values.length);
+  }
+
+  /**
+   * Prints, for every word, the three other words whose vectors are most similar to it, where
+   * similarity is the reciprocal of {@code distanceMeasure}. A word for which fewer than three
+   * neighbours could be ranked is reported on standard error instead.
+   *
+   * @param vocabulary      words; {@code vocabulary.get(i - 1)} labels column {@code i}
+   * @param wordVectors     matrix whose columns, starting at index 1, hold the word vectors
+   * @param distanceMeasure distance whose reciprocal is used as similarity
+   */
+  private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) {
+    int columns = wordVectors.getColumnDimension();
+    for (int i = 1; i < columns; i++) {
+      double[] subjectVector = wordVector(wordVectors, i);
+      // running top-3: (maxSimilarity, j0) is the best match, then (…1, j1), then (…2, j2)
+      double maxSimilarity = -Double.MAX_VALUE;
+      double maxSimilarity1 = -Double.MAX_VALUE;
+      double maxSimilarity2 = -Double.MAX_VALUE;
+      int j0 = -1;
+      int j1 = -1;
+      int j2 = -1;
+      for (int j = 1; j < columns; j++) {
+        if (i == j) {
+          continue;
+        }
+        double similarity = 1d / distanceMeasure.compute(subjectVector, wordVector(wordVectors, j));
+        if (similarity > maxSimilarity) {
+          maxSimilarity2 = maxSimilarity1;
+          j2 = j1;
+
+          maxSimilarity1 = maxSimilarity;
+          j1 = j0;
+
+          maxSimilarity = similarity;
+          j0 = j;
+        } else if (similarity > maxSimilarity1) {
+          maxSimilarity2 = maxSimilarity1;
+          j2 = j1;
+
+          maxSimilarity1 = similarity;
+          j1 = j;
+        } else if (similarity > maxSimilarity2) {
+          maxSimilarity2 = similarity;
+          j2 = j;
+        }
+      }
+      if (j0 > 0 && j1 > 0 && j2 > 0) {
+        System.out.println(vocabulary.get(i - 1) + " -> "
+                + vocabulary.get(j0 - 1) + ", "
+                + vocabulary.get(j1 - 1) + ", "
+                + vocabulary.get(j2 - 1));
+      } else {
+        // column i is labelled by vocabulary.get(i - 1), exactly as in the branch above
+        System.err.println("no similarity for '" + vocabulary.get(i - 1) + "' with " + distanceMeasure);
+      }
+    }
+  }
+}

Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
------------------------------------------------------------------------------
    svn:eol-style = native




---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org