You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2016/03/14 09:42:10 UTC
svn commit: r1734888 - /labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
Author: tommaso
Date: Mon Mar 14 08:42:10 2016
New Revision: 1734888
URL: http://svn.apache.org/viewvc?rev=1734888&view=rev
Log:
Fixed a bug in mini-batch handling.
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1734888&r1=1734887&r2=1734888&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Mon Mar 14 08:42:10 2016
@@ -19,7 +19,6 @@
package org.apache.yay;
import com.google.common.base.Splitter;
-import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
@@ -29,20 +28,15 @@ import org.apache.commons.math3.linear.R
import java.io.IOException;
import java.nio.ByteBuffer;
-import java.nio.CharBuffer;
import java.nio.channels.SeekableByteChannel;
-import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
-import java.util.Set;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.regex.Pattern;
@@ -236,24 +230,23 @@ public class SkipGramNetwork {
long start = System.currentTimeMillis();
int c = 1;
+ RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length);
+ RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length);
while (true) {
- RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length);
- RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length);
int i = 0;
for (int k = j * configuration.batchSize; k < j * configuration.batchSize + configuration.batchSize; k++) {
Sample sample = samples[k % samples.length];
- x.setRow(i, ArrayUtils.addAll(sample.getInputs()));
- y.setRow(i, ArrayUtils.addAll(sample.getOutputs()));
+ x.setRow(i, sample.getInputs());
+ y.setRow(i, sample.getOutputs());
i++;
}
+ j++;
long time = (System.currentTimeMillis() - start) / 1000;
- if (iterations % (1 + (configuration.maxIterations / 100)) == 0 || time % 300 == 0) {
- if (time > 60 * c) {
- c += 1;
- System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
- }
+ if (iterations % (1 + (configuration.maxIterations / 100)) == 0 && time > 60 * c) {
+ c += 1;
+ System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
}
RealMatrix w0t = weights[0].transpose();
@@ -384,9 +377,9 @@ public class SkipGramNetwork {
System.out.println("started with cost = " + dataLoss + " + " + regLoss + " = " + newCost);
}
- if (Double.POSITIVE_INFINITY == newCost || newCost > cost) {
+ if (Double.POSITIVE_INFINITY == newCost) {
throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
- } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations || cost - newCost < configuration.threshold)) {
+ } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
cost = newCost;
System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
break;
@@ -1060,11 +1053,9 @@ public class SkipGramNetwork {
}
}
- List<String> os = new LinkedList<>();
double[] doubles = new double[window - 1];
for (int i = 0; i < doubles.length; i++) {
String o = new String(outputWords.get(i));
- os.add(o);
doubles[i] = (double) vocabulary.indexOf(o);
}
@@ -1143,91 +1134,6 @@ public class SkipGramNetwork {
}
}
- private Queue<List<byte[]>> getFragmentsOld(Path path, int w) throws IOException {
- long start = System.currentTimeMillis();
- Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>();
-
- ByteBuffer buf = ByteBuffer.allocate(100);
- try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
-
- String encoding = System.getProperty("file.encoding");
- StringBuilder previous = new StringBuilder();
- Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
- while (sbc.read(buf) > 0) {
- buf.rewind();
- CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
- String string = cleanString(charBuffer.toString());
- List<String> split = splitter.splitToList(string);
- int splitSize = split.size();
- if (splitSize > w) {
- for (int j = 0; j < splitSize - w; j++) {
- List<byte[]> fragment = new ArrayList<>(w);
- String str = split.get(j);
- fragment.add(previous.append(str).toString().getBytes());
- for (int i = 1; i < w; i++) {
- String s = split.get(i + j);
- fragment.add(s.getBytes());
- }
- // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
- fragments.add(fragment);
- previous = new StringBuilder();
- }
- previous = new StringBuilder().append(split.get(splitSize - 1));
- } else if (split.size() == w) {
- previous.append(string);
- }
- buf.flip();
- }
- } catch (IOException x) {
- System.err.println("caught exception: " + x);
- } finally {
- buf.clear();
- }
- long end = System.currentTimeMillis();
- System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
- return fragments;
- }
-
- private List<String> getVocabulary(Path path) throws IOException {
- Set<String> vocabulary = new HashSet<>();
- ByteBuffer buf = ByteBuffer.allocate(100);
- try (SeekableByteChannel sbc = Files.newByteChannel(path)) {
-
- String encoding = System.getProperty("file.encoding");
- StringBuilder previous = new StringBuilder();
- Splitter splitter = Splitter.on(Pattern.compile("[\\\n\\s]")).omitEmptyStrings().trimResults();
- while (sbc.read(buf) > 0) {
- buf.rewind();
- CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
- String string = cleanString(charBuffer.toString());
- List<String> split = splitter.splitToList(string);
- int splitSize = split.size();
- if (splitSize > 1) {
- String term = previous.append(split.get(0)).toString();
- vocabulary.add(term.intern());
- for (int i = 1; i < splitSize - 1; i++) {
- String term2 = split.get(i);
- vocabulary.add(term2.intern());
- }
- previous = new StringBuilder().append(split.get(splitSize - 1));
- } else if (split.size() == 1) {
- previous.append(string);
- }
- buf.flip();
- }
- } catch (IOException x) {
- System.err.println("caught exception: " + x);
- } finally {
- buf.clear();
- }
- List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()]));
- Collections.sort(list);
-// for (String iw : vocabulary) {
-// System.out.println(iw +"->"+Arrays.toString(ConversionUtils.hotEncode(iw.getBytes(), list)));
-// }
- return list;
- }
-
private String cleanString(String s) {
return s.toLowerCase().replaceAll("\\.", " \\.").replaceAll("\\;", " \\;").replaceAll("\\,", " \\,").replaceAll("\\:", " \\:").replaceAll("\\-\\s", "").replaceAll("\\\"", " \\\"");
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org