You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2016/03/14 09:42:10 UTC

svn commit: r1734888 - /labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java

Author: tommaso
Date: Mon Mar 14 08:42:10 2016
New Revision: 1734888

URL: http://svn.apache.org/viewvc?rev=1734888&view=rev
Log:
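Fixed a bug in mini-batch selection: the batch index is now advanced on every iteration, and the batch matrices are allocated once outside the training loop.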

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1734888&r1=1734887&r2=1734888&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Mon Mar 14 08:42:10 2016
@@ -19,7 +19,6 @@
 package org.apache.yay;
 
 import com.google.common.base.Splitter;
-import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.math3.distribution.UniformRealDistribution;
 import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
@@ -29,20 +28,15 @@ import org.apache.commons.math3.linear.R
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.nio.CharBuffer;
 import java.nio.channels.SeekableByteChannel;
-import java.nio.charset.Charset;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Queue;
-import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedDeque;
 import java.util.regex.Pattern;
 
@@ -236,24 +230,23 @@ public class SkipGramNetwork {
 
     long start = System.currentTimeMillis();
     int c = 1;
+    RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length);
+    RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length);
     while (true) {
 
-      RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length);
-      RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length);
       int i = 0;
       for (int k = j * configuration.batchSize; k < j * configuration.batchSize + configuration.batchSize; k++) {
         Sample sample = samples[k % samples.length];
-        x.setRow(i, ArrayUtils.addAll(sample.getInputs()));
-        y.setRow(i, ArrayUtils.addAll(sample.getOutputs()));
+        x.setRow(i, sample.getInputs());
+        y.setRow(i, sample.getOutputs());
         i++;
       }
+      j++;
 
       long time = (System.currentTimeMillis() - start) / 1000;
-      if (iterations % (1 + (configuration.maxIterations / 100)) == 0 || time % 300 == 0) {
-        if (time > 60 * c) {
-          c += 1;
-          System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
-        }
+      if (iterations % (1 + (configuration.maxIterations / 100)) == 0 && time > 60 * c) {
+        c += 1;
+        System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
       }
 
       RealMatrix w0t = weights[0].transpose();
@@ -384,9 +377,9 @@ public class SkipGramNetwork {
         System.out.println("started with cost = " + dataLoss + " + " + regLoss + " = " + newCost);
       }
 
-      if (Double.POSITIVE_INFINITY == newCost || newCost > cost) {
+      if (Double.POSITIVE_INFINITY == newCost) {
         throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
-      } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations || cost - newCost < configuration.threshold)) {
+      } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
         cost = newCost;
         System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
         break;
@@ -1060,11 +1053,9 @@ public class SkipGramNetwork {
           }
         }
 
-        List<String> os = new LinkedList<>();
         double[] doubles = new double[window - 1];
         for (int i = 0; i < doubles.length; i++) {
           String o = new String(outputWords.get(i));
-          os.add(o);
           doubles[i] = (double) vocabulary.indexOf(o);
         }
 
@@ -1143,91 +1134,6 @@ public class SkipGramNetwork {
       }
     }
 
-    private Queue<List<byte[]>> getFragmentsOld(Path path, int w) throws IOException {
-      long start = System.currentTimeMillis();
-      Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>();
-
-      ByteBuffer buf = ByteBuffer.allocate(100);
-      try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
-
-        String encoding = System.getProperty("file.encoding");
-        StringBuilder previous = new StringBuilder();
-        Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
-        while (sbc.read(buf) > 0) {
-          buf.rewind();
-          CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
-          String string = cleanString(charBuffer.toString());
-          List<String> split = splitter.splitToList(string);
-          int splitSize = split.size();
-          if (splitSize > w) {
-            for (int j = 0; j < splitSize - w; j++) {
-              List<byte[]> fragment = new ArrayList<>(w);
-              String str = split.get(j);
-              fragment.add(previous.append(str).toString().getBytes());
-              for (int i = 1; i < w; i++) {
-                String s = split.get(i + j);
-                fragment.add(s.getBytes());
-              }
-              // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
-              fragments.add(fragment);
-              previous = new StringBuilder();
-            }
-            previous = new StringBuilder().append(split.get(splitSize - 1));
-          } else if (split.size() == w) {
-            previous.append(string);
-          }
-          buf.flip();
-        }
-      } catch (IOException x) {
-        System.err.println("caught exception: " + x);
-      } finally {
-        buf.clear();
-      }
-      long end = System.currentTimeMillis();
-      System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
-      return fragments;
-    }
-
-    private List<String> getVocabulary(Path path) throws IOException {
-      Set<String> vocabulary = new HashSet<>();
-      ByteBuffer buf = ByteBuffer.allocate(100);
-      try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
-
-        String encoding = System.getProperty("file.encoding");
-        StringBuilder previous = new StringBuilder();
-        Splitter splitter = Splitter.on(Pattern.compile("[\\\n\\s]")).omitEmptyStrings().trimResults();
-        while (sbc.read(buf) > 0) {
-          buf.rewind();
-          CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
-          String string = cleanString(charBuffer.toString());
-          List<String> split = splitter.splitToList(string);
-          int splitSize = split.size();
-          if (splitSize > 1) {
-            String term = previous.append(split.get(0)).toString();
-            vocabulary.add(term.intern());
-            for (int i = 1; i < splitSize - 1; i++) {
-              String term2 = split.get(i);
-              vocabulary.add(term2.intern());
-            }
-            previous = new StringBuilder().append(split.get(splitSize - 1));
-          } else if (split.size() == 1) {
-            previous.append(string);
-          }
-          buf.flip();
-        }
-      } catch (IOException x) {
-        System.err.println("caught exception: " + x);
-      } finally {
-        buf.clear();
-      }
-      List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()]));
-      Collections.sort(list);
-//    for (String iw : vocabulary) {
-//      System.out.println(iw +"->"+Arrays.toString(ConversionUtils.hotEncode(iw.getBytes(), list)));
-//    }
-      return list;
-    }
-
     private String cleanString(String s) {
       return s.toLowerCase().replaceAll("\\.", " \\.").replaceAll("\\;", " \\;").replaceAll("\\,", " \\,").replaceAll("\\:", " \\:").replaceAll("\\-\\s", "").replaceAll("\\\"", " \\\"");
     }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org