You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") in the original page; that link is not preserved in this plain-text rendering.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/08/30 21:04:46 UTC
[01/17] incubator-joshua git commit: Fix a number of issues: - Reader
now implements AutoCloseable - Close various leaks from LineReader -
LineReader no longer implements custom finalize(). Resources should be
explicitly closed when no longer needed. T[...] (subject line truncated in the archive)
Repository: incubator-joshua
Updated Branches:
refs/heads/7 bb3c3004e -> b0b706272
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/zmert/MertCore.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/zmert/MertCore.java b/src/main/java/org/apache/joshua/zmert/MertCore.java
index c0d470d..00398b3 100644
--- a/src/main/java/org/apache/joshua/zmert/MertCore.java
+++ b/src/main/java/org/apache/joshua/zmert/MertCore.java
@@ -53,13 +53,14 @@ import org.apache.joshua.decoder.Decoder;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.metrics.EvaluationMetric;
import org.apache.joshua.util.StreamGobbler;
+import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This code was originally written by Omar Zaidan. In September of 2012, it was augmented to support
* a sparse feature implementation.
- *
+ *
* @author Omar Zaidan
*/
@@ -71,7 +72,6 @@ public class MertCore {
private TreeSet<Integer>[] indicesOfInterest_all;
private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
- private final Runtime myRuntime = Runtime.getRuntime();
private final static double NegInf = (-1.0 / 0.0);
private final static double PosInf = (+1.0 / 0.0);
@@ -255,26 +255,26 @@ public class MertCore {
// private int useDisk;
- public MertCore(JoshuaConfiguration joshuaConfiguration)
+ public MertCore(JoshuaConfiguration joshuaConfiguration)
{
this.joshuaConfiguration = joshuaConfiguration;
}
- public MertCore(String[] args, JoshuaConfiguration joshuaConfiguration) {
+ public MertCore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(args);
initialize(0);
}
- public MertCore(String configFileName,JoshuaConfiguration joshuaConfiguration) {
+ public MertCore(String configFileName,JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(cfgFileToArgsArray(configFileName));
initialize(0);
}
- private void initialize(int randsToSkip) {
+ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
randGen = new Random(seed);
@@ -298,12 +298,12 @@ public class MertCore {
if (! new File(refFile).exists())
refFile = refFileName + ".0";
if (! new File(refFile).exists()) {
- throw new RuntimeException(String.format("* FATAL: can't find first reference file '%s{0,.0}'", refFileName));
+ throw new IOException(String.format("* FATAL: can't find first reference file '%s{0,.0}'", refFileName));
}
- numSentences = countLines(refFile);
+ numSentences = new ExistingUTF8EncodedTextFile(refFile).getNumberOfLines();
} else {
- numSentences = countLines(refFileName);
+ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines();
}
processDocInfo();
@@ -315,7 +315,7 @@ public class MertCore {
- numParams = countNonEmptyLines(paramsFileName) - 1;
+ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
// the parameter file contains one line per parameter
// and one line for the normalization method
@@ -375,7 +375,7 @@ public class MertCore {
reference_readers[i] = new BufferedReader(new InputStreamReader(new FileInputStream(new File(refFile)), "utf8"));
}
}
-
+
for (int i = 0; i < numSentences; ++i) {
for (int r = 0; r < refsPerSen; ++r) {
// read the rth reference translation for the ith sentence
@@ -384,7 +384,7 @@ public class MertCore {
}
// close all the reference files
- for (int i = 0; i < refsPerSen; i++)
+ for (int i = 0; i < refsPerSen; i++)
reference_readers[i].close();
// read in decoder command, if any
@@ -1522,10 +1522,10 @@ public class MertCore {
/*
* line format:
- *
+ *
* i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val
* .*
- *
+ *
* Updated September 2012: features can now be named (for sparse feature compatibility).
* You must name all features or none of them.
*/
@@ -1827,7 +1827,7 @@ public class MertCore {
// belongs to,
// and its order in that document. (can also use '-' instead of '_')
- int docInfoSize = countNonEmptyLines(docInfoFileName);
+ int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
if (docInfoSize < numSentences) { // format #1 or #2
numDocuments = docInfoSize;
@@ -1935,13 +1935,13 @@ public class MertCore {
/*
* InputStream inStream = new FileInputStream(new File(origFileName)); BufferedReader inFile =
* new BufferedReader(new InputStreamReader(inStream, "utf8"));
- *
+ *
* FileOutputStream outStream = new FileOutputStream(newFileName, false); OutputStreamWriter
* outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new
* BufferedWriter(outStreamWriter);
- *
+ *
* String line; while(inFile.ready()) { line = inFile.readLine(); writeLine(line, outFile); }
- *
+ *
* inFile.close(); outFile.close();
*/
return true;
@@ -2454,12 +2454,12 @@ public class MertCore {
/*
* 1: -docSet bottom 8d 2: -docSet bottom 25% the bottom ceil(0.20*numDocs) documents 3: -docSet
* top 8d 4: -docSet top 25% the top ceil(0.20*numDocs) documents
- *
+ *
* 5: -docSet window 11d around 90percentile 11 docs centered around 80th percentile (complain
* if not enough docs; don't adjust) 6: -docSet window 11d around 40rank 11 docs centered around
* doc ranked 50 (complain if not enough docs; don't adjust)
- *
- *
+ *
+ *
* [0]: method (0-6) [1]: first (1-indexed) [2]: last (1-indexed) [3]: size [4]: center [5]:
* arg1 (-1 for method 0) [6]: arg2 (-1 for methods 0-4)
*/
@@ -2505,7 +2505,6 @@ public class MertCore {
info[1] = info[4] - ((info[3] - 1) / 2);
info[2] = info[4] + ((info[3] - 1) / 2);
}
-
}
private void checkFile(String fileName) {
@@ -2688,108 +2687,11 @@ public class MertCore {
return str;
}
- private int countLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null) ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
- private int countNonEmptyLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null && line.length() > 0) ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
private String fullPath(String dir, String fileName) {
File dummyFile = new File(dir, fileName);
return dummyFile.getAbsolutePath();
}
- @SuppressWarnings("unused")
- private void cleanupMemory() {
- cleanupMemory(100, false);
- }
-
- @SuppressWarnings("unused")
- private void cleanupMemorySilently() {
- cleanupMemory(100, true);
- }
-
- @SuppressWarnings("static-access")
- private void cleanupMemory(int reps, boolean silent) {
- int bytesPerMB = 1024 * 1024;
-
- long totalMemBefore = myRuntime.totalMemory();
- long freeMemBefore = myRuntime.freeMemory();
- long usedMemBefore = totalMemBefore - freeMemBefore;
-
-
- long usedCurr = usedMemBefore;
- long usedPrev = usedCurr;
-
- // perform garbage collection repeatedly, until there is no decrease in
- // the amount of used memory
- for (int i = 1; i <= reps; ++i) {
- myRuntime.runFinalization();
- myRuntime.gc();
- (Thread.currentThread()).yield();
-
- usedPrev = usedCurr;
- usedCurr = myRuntime.totalMemory() - myRuntime.freeMemory();
-
- if (usedCurr == usedPrev) break;
- }
-
-
- if (!silent) {
- long totalMemAfter = myRuntime.totalMemory();
- long freeMemAfter = myRuntime.freeMemory();
- long usedMemAfter = totalMemAfter - freeMemAfter;
-
- println("GC: d_used = " + ((usedMemAfter - usedMemBefore) / bytesPerMB) + " MB "
- + "(d_tot = " + ((totalMemAfter - totalMemBefore) / bytesPerMB) + " MB).", 2);
- }
- }
-
- @SuppressWarnings("unused")
- private void printMemoryUsage() {
- int bytesPerMB = 1024 * 1024;
- long totalMem = myRuntime.totalMemory();
- long freeMem = myRuntime.freeMemory();
- long usedMem = totalMem - freeMem;
-
- println("Allocated memory: " + (totalMem / bytesPerMB) + " MB " + "(of which "
- + (usedMem / bytesPerMB) + " MB is being used).", 2);
- }
-
private void println(Object obj, int priority) {
if (priority <= verbosity) println(obj);
}
@@ -2887,58 +2789,13 @@ public class MertCore {
lastUsedIndex[i] += 1;
}
- @SuppressWarnings("unused")
- private HashSet<Integer> indicesToDiscard(double[] slope, double[] offset) {
- // some lines can be eliminated: the ones that have a lower offset
- // than some other line with the same slope.
- // That is, for any k1 and k2:
- // if slope[k1] = slope[k2] and offset[k1] > offset[k2],
- // then k2 can be eliminated.
- // (This is actually important to do as it eliminates a bug.)
- // print("discarding: ",4);
-
- int numCandidates = slope.length;
- HashSet<Integer> discardedIndices = new HashSet<Integer>();
- HashMap<Double, Integer> indicesOfSlopes = new HashMap<Double, Integer>();
- // maps slope to index of best candidate that has that slope.
- // ("best" as in the one with the highest offset)
-
- for (int k1 = 0; k1 < numCandidates; ++k1) {
- double currSlope = slope[k1];
- if (!indicesOfSlopes.containsKey(currSlope)) {
- indicesOfSlopes.put(currSlope, k1);
- } else {
- int existingIndex = indicesOfSlopes.get(currSlope);
- if (offset[existingIndex] > offset[k1]) {
- discardedIndices.add(k1);
- // print(k1 + " ",4);
- } else if (offset[k1] > offset[existingIndex]) {
- indicesOfSlopes.put(currSlope, k1);
- discardedIndices.add(existingIndex);
- // print(existingIndex + " ",4);
- }
- }
- }
-
-
- // old way of doing it; takes quadratic time (vs. linear time above)
- /*
- * for (int k1 = 0; k1 < numCandidates; ++k1) { for (int k2 = 0; k2 < numCandidates; ++k2) { if
- * (k1 != k2 && slope[k1] == slope[k2] && offset[k1] > offset[k2]) { discardedIndices.add(k2);
- * // print(k2 + " ",4); } } }
- */
-
- // println("",4);
- return discardedIndices;
- } // indicesToDiscard(double[] slope, double[] offset)
-
- public static void main(String[] args) {
+ public static void main(String[] args) throws FileNotFoundException, IOException {
String configFileName = args[0];
String stateFileName = args[1];
int currIteration = Integer.parseInt(args[2]);
JoshuaConfiguration joshuaConfiguration = new JoshuaConfiguration();
-
+
MertCore DMC = new MertCore(joshuaConfiguration); // dummy MertCore object
// if bad args[], System.exit(80)
@@ -3140,49 +2997,49 @@ public class MertCore {
/*
- *
+ *
* fake: ----- ex2_N300: java -javaagent:shiftone-jrat.jar -Xmx300m -cp bin joshua.ZMERT.ZMERT -dir
* MERT_example -s src.txt -r ref.all -rps 4 -cmd decoder_command_ex2.txt -dcfg config_ex2.txt
* -decOut nbest_ex2.out -N 300 -p params.txt -maxIt 25 -opi 0 -ipi 20 -v 2 -rand 0 -seed
* 1226091488390 -save 1 -fake nbest_ex2.out.N300.it >
* ex2_N300ipi20opi0_300max+defratios.it10.noMemRep.bugFixes.monitored.txt
- *
+ *
* ex2_N500: java -javaagent:shiftone-jrat.jar -Xmx300m -cp bin joshua.ZMERT.ZMERT -dir MERT_example
* -s src.txt -r ref.all -rps 4 -cmd decoder_command_ex2.txt -dcfg config_ex2.txt -decOut
* nbest_ex2.out -N 500 -p params.txt -maxIt 25 -opi 0 -ipi 20 -v 2 -rand 0 -seed 1226091488390
* -save 1 -fake nbest_ex2.out.N500.it >
* ex2_N500ipi20opi0_300max+defratios.it05.noMemRep.bugFixes.monitored.txt
- *
+ *
* exL_N300__600max: java -javaagent:shiftone-jrat.jar -Xmx600m -cp bin joshua.ZMERT.ZMERT -dir
* MERT_example -s mt06_source.txt -r mt06_ref.all -rps 4 -cmd decoder_command_ex2.txt -dcfg
* config_ex2.txt -decOut nbest_exL.out -N 300 -p params.txt -maxIt 5 -opi 0 -ipi 20 -v 2 -rand 0
* -seed 1226091488390 -save 1 -fake nbest_exL.out.it >
* exL_N300ipi20opi0_600max+defratios.it05.noMemRep.bugFixes.monitored.txt
- *
+ *
* exL_N300__300max: java -javaagent:shiftone-jrat.jar -Xmx300m -cp bin joshua.ZMERT.ZMERT -dir
* MERT_example -s mt06_source.txt -r mt06_ref.all -rps 4 -cmd decoder_command_ex2.txt -dcfg
* config_ex2.txt -decOut nbest_exL.out -N 300 -p params.txt -maxIt 5 -opi 0 -ipi 20 -v 2 -rand 0
* -seed 1226091488390 -save 1 -fake nbest_exL.out.it >
* exL_N300ipi20opi0_300max+defratios.it05.noMemRep.bugFixes.monitored.txt
- *
+ *
* gen: ---- ex2_N300: make sure top_n=300 in MERT_example\config_ex2.txt java
* -javaagent:shiftone-jrat.jar -Xmx300m -cp bin joshua.ZMERT.ZMERT -dir MERT_example -s src.txt -r
* ref.all -rps 4 -cmd decoder_command_ex2.txt -dcfg config_ex2.txt -decOut nbest_ex2.out -N 300 -p
* params.txt -maxIt 25 -opi 0 -ipi 20 -v 2 -rand 0 -seed 1226091488390 -save 1 >
* ex2_N300ipi20opi0_300max+defratios.itxx.monitored.txt.gen
- *
+ *
* ex2_N500: make sure top_n=500 in MERT_example\config_ex2.txt java -javaagent:shiftone-jrat.jar
* -Xmx300m -cp bin joshua.ZMERT.ZMERT -dir MERT_example -s src.txt -r ref.all -rps 4 -cmd
* decoder_command_ex2.txt -dcfg config_ex2.txt -decOut nbest_ex2.out -N 500 -p params.txt -maxIt 25
* -opi 0 -ipi 20 -v 2 -rand 0 -seed 1226091488390 -save 1 >
* ex2_N500ipi20opi0_300max+defratios.itxx.monitored.txt.gen
- *
+ *
* exL_N300__600max: run on CLSP machines only! (e.g. z12) $JAVA_bin/java
* -javaagent:shiftone-jrat.jar -Xmx600m -cp bin joshua.ZMERT.ZMERT -dir YOURDIR -s mt06_source.txt
* -r mt06_ref.all -rps 4 -cmd decoder_command.txt -dcfg config_exL.txt -decOut nbest_exL.out -N 300
* -p params.txt -maxIt 25 -opi 0 -ipi 20 -v 2 -rand 0 -seed 1226091488390 -save 1 >
* exL_N300ipi20opi0_600max+defratios.itxx.monitored.txt.gen
- *
+ *
* exL_N300__300max: run on CLSP machines only! (e.g. z12) $JAVA_bin/java
* -javaagent:shiftone-jrat.jar -Xmx300m -cp bin joshua.ZMERT.ZMERT -dir YOURDIR -s mt06_source.txt
* -r mt06_ref.all -rps 4 -cmd decoder_command.txt -dcfg config_exL.txt -decOut nbest_exL.out -N 300
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/test/java/org/apache/joshua/packed/Benchmark.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/packed/Benchmark.java b/src/test/java/org/apache/joshua/packed/Benchmark.java
index 41cf2a0..7c4fc80 100644
--- a/src/test/java/org/apache/joshua/packed/Benchmark.java
+++ b/src/test/java/org/apache/joshua/packed/Benchmark.java
@@ -18,9 +18,6 @@
*/
package org.apache.joshua.packed;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
@@ -30,30 +27,33 @@ import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.util.Random;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
/**
* This program runs a little benchmark to check reading speed on various data
* representations.
- *
+ *
* Usage: java Benchmark PACKED_GRAMMAR_DIR TIMES
*/
-public class Benchmark {
+public class Benchmark implements AutoCloseable{
-
private static final Logger LOG = LoggerFactory.getLogger(Benchmark.class);
private IntBuffer intBuffer;
private MappedByteBuffer byteBuffer;
private int[] intArray;
+ private final FileInputStream fin;
public Benchmark(String dir) throws IOException {
File file = new File(dir + "/slice_00000.source");
-
- FileChannel source_channel = new FileInputStream(file).getChannel();
+ this.fin = new FileInputStream(file);
+ FileChannel source_channel = this.fin.getChannel();
int byte_size = (int) source_channel.size();
int int_size = byte_size / 4;
- byteBuffer = source_channel.map(MapMode.READ_ONLY, 0, byte_size);
+ byteBuffer = source_channel.map(MapMode.READ_ONLY, 0, byte_size);
intBuffer = byteBuffer.asIntBuffer();
intArray = new int[int_size];
@@ -120,7 +120,13 @@ public class Benchmark {
}
public static void main(String args[]) throws IOException {
- Benchmark pr = new Benchmark(args[0]);
- pr.benchmark( Integer.parseInt(args[1]));
+ try (Benchmark pr = new Benchmark(args[0]);) {
+ pr.benchmark( Integer.parseInt(args[1]));
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ this.fin.close();
}
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java b/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
index 8192cb3..8c959c4 100644
--- a/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
+++ b/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
@@ -18,12 +18,15 @@
*/
package org.apache.joshua.system;
+import static org.mockito.Mockito.doReturn;
+import static org.testng.Assert.assertTrue;
+
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
-import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import org.apache.joshua.decoder.Decoder;
@@ -37,9 +40,6 @@ import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
-import static org.mockito.Mockito.doReturn;
-import static org.testng.Assert.assertTrue;
-
/**
* Integration test for multithreaded Joshua decoder tests. Grammar used is a
* toy packed grammar.
@@ -124,8 +124,8 @@ public class MultithreadedTranslationTests {
// engine.
TranslationRequestStream req = new TranslationRequestStream(
new BufferedReader(new InputStreamReader(new ByteArrayInputStream(sb.toString()
- .getBytes(Charset.forName("UTF-8"))))), joshuaConfig);
-
+ .getBytes(StandardCharsets.UTF_8)))), joshuaConfig);
+
ByteArrayOutputStream output = new ByteArrayOutputStream();
// WHEN
[05/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/test/java/org/apache/joshua/packed/Benchmark.java
----------------------------------------------------------------------
diff --cc joshua-core/src/test/java/org/apache/joshua/packed/Benchmark.java
index 41cf2a0,0000000..7c4fc80
mode 100644,000000..100644
--- a/joshua-core/src/test/java/org/apache/joshua/packed/Benchmark.java
+++ b/joshua-core/src/test/java/org/apache/joshua/packed/Benchmark.java
@@@ -1,126 -1,0 +1,132 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.packed;
+
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
-
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.IntBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.FileChannel.MapMode;
+import java.util.Random;
+
++import org.slf4j.Logger;
++import org.slf4j.LoggerFactory;
++
+/**
+ * This program runs a little benchmark to check reading speed on various data
+ * representations.
- *
++ *
+ * Usage: java Benchmark PACKED_GRAMMAR_DIR TIMES
+ */
+
- public class Benchmark {
++public class Benchmark implements AutoCloseable{
+
-
+ private static final Logger LOG = LoggerFactory.getLogger(Benchmark.class);
+
+ private IntBuffer intBuffer;
+ private MappedByteBuffer byteBuffer;
+ private int[] intArray;
++ private final FileInputStream fin;
+
+ public Benchmark(String dir) throws IOException {
+ File file = new File(dir + "/slice_00000.source");
-
- FileChannel source_channel = new FileInputStream(file).getChannel();
++ this.fin = new FileInputStream(file);
++ FileChannel source_channel = this.fin.getChannel();
+ int byte_size = (int) source_channel.size();
+ int int_size = byte_size / 4;
+
- byteBuffer = source_channel.map(MapMode.READ_ONLY, 0, byte_size);
++ byteBuffer = source_channel.map(MapMode.READ_ONLY, 0, byte_size);
+ intBuffer = byteBuffer.asIntBuffer();
+
+ intArray = new int[int_size];
+ intBuffer.get(intArray);
+ }
+
+ public void benchmark(int times) {
+ LOG.info("Beginning benchmark.");
+
+ Random r = new Random();
+ r.setSeed(1234567890);
+ int[] positions = new int[1000];
+ for (int i = 0; i < positions.length; i++)
+ positions[i] = r.nextInt(intArray.length);
+
+ long sum;
+
+ long start_time = System.currentTimeMillis();
+
+ sum = 0;
+ for (int t = 0; t < times; t++)
+ for (int i = 0; i < positions.length; i++)
+ sum += byteBuffer.getInt(positions[i] * 4);
+ LOG.info("Sum: {}", sum);
+ long byte_time = System.currentTimeMillis();
+
+ sum = 0;
+ for (int t = 0; t < times; t++)
+ for (int i = 0; i < positions.length; i++)
+ sum += intBuffer.get(positions[i]);
+ LOG.info("Sum: {}", sum);
+ long int_time = System.currentTimeMillis();
+
+ sum = 0;
+ for (int t = 0; t < times; t++)
+ for (int i = 0; i < positions.length; i++)
+ sum += intArray[positions[i]];
+ LOG.info("Sum: {}", sum);
+ long array_time = System.currentTimeMillis();
+
+ sum = 0;
+ for (int t = 0; t < times; t++)
+ for (int i = 0; i < (intArray.length / 8); i++)
+ sum += intArray[i * 6] + intArray[i * 6 + 2];
+ LOG.info("Sum: {}", sum);
+ long mult_time = System.currentTimeMillis();
+
+ sum = 0;
+ for (int t = 0; t < times; t++) {
+ int index = 0;
+ for (int i = 0; i < (intArray.length / 8); i++) {
+ sum += intArray[index] + intArray[index + 2];
+ index += 6;
+ }
+ }
+ LOG.info("Sum: {}", sum);
+ long add_time = System.currentTimeMillis();
+
+ LOG.info("ByteBuffer: {}", (byte_time - start_time));
+ LOG.info("IntBuffer: {}", (int_time - byte_time));
+ LOG.info("Array: {}", (array_time - int_time));
+ LOG.info("Multiply: {}", (mult_time - array_time));
+ LOG.info("Add: {}", (add_time - mult_time));
+ }
+
+ public static void main(String args[]) throws IOException {
- Benchmark pr = new Benchmark(args[0]);
- pr.benchmark( Integer.parseInt(args[1]));
++ try (Benchmark pr = new Benchmark(args[0]);) {
++ pr.benchmark( Integer.parseInt(args[1]));
++ }
++ }
++
++ @Override
++ public void close() throws IOException {
++ this.fin.close();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
----------------------------------------------------------------------
diff --cc joshua-core/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
index 10872d0,0000000..01d3963
mode 100644,000000..100644
--- a/joshua-core/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
+++ b/joshua-core/src/test/java/org/apache/joshua/system/MultithreadedTranslationTests.java
@@@ -1,180 -1,0 +1,180 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.system;
+
++import static org.mockito.Mockito.doReturn;
++import static org.testng.Assert.assertTrue;
++
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
- import java.nio.charset.Charset;
++import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.Translation;
+import org.apache.joshua.decoder.TranslationResponseStream;
+import org.apache.joshua.decoder.io.TranslationRequestStream;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.mockito.Mockito;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
- import static org.mockito.Mockito.doReturn;
- import static org.testng.Assert.assertTrue;
-
+/**
+ * Integration test for multithreaded Joshua decoder tests. Grammar used is a
+ * toy packed grammar.
+ *
+ * @author Kellen Sunderland kellen.sunderland@gmail.com
+ */
+public class MultithreadedTranslationTests {
+
+ private JoshuaConfiguration joshuaConfig = null;
+ private Decoder decoder = null;
+ private static final String INPUT = "A K B1 U Z1 Z2 B2 C";
+ private static final String EXCEPTION_MESSAGE = "This exception should properly propagate";
+ private int previousLogLevel;
+ private final static long NANO_SECONDS_PER_SECOND = 1_000_000_000;
+
+ @BeforeClass
+ public void setUp() throws Exception {
+ joshuaConfig = new JoshuaConfiguration();
+ joshuaConfig.search_algorithm = "cky";
+ joshuaConfig.mark_oovs = false;
+ joshuaConfig.pop_limit = 100;
+ joshuaConfig.use_unique_nbest = false;
+ joshuaConfig.include_align_index = false;
+ joshuaConfig.topN = 0;
+ joshuaConfig.tms.add("thrax -owner pt -maxspan 20 -path src/test/resources/wa_grammar.packed");
+ joshuaConfig.tms.add("thrax -owner glue -maxspan -1 -path src/test/resources/grammar.glue");
+ joshuaConfig.goal_symbol = "[GOAL]";
+ joshuaConfig.default_non_terminal = "[X]";
+ joshuaConfig.features.add("OOVPenalty");
+ joshuaConfig.weights.add("tm_pt_0 1");
+ joshuaConfig.weights.add("tm_pt_1 1");
+ joshuaConfig.weights.add("tm_pt_2 1");
+ joshuaConfig.weights.add("tm_pt_3 1");
+ joshuaConfig.weights.add("tm_pt_4 1");
+ joshuaConfig.weights.add("tm_pt_5 1");
+ joshuaConfig.weights.add("tm_glue_0 1");
+ joshuaConfig.weights.add("OOVPenalty 2");
+ joshuaConfig.num_parallel_decoders = 500; // This will enable 500 parallel
+ // decoders to run at once.
+ // Useful to help flush out
+ // concurrency errors in
+ // underlying
+ // data-structures.
+ this.decoder = new Decoder(joshuaConfig);
+ previousLogLevel = Decoder.VERBOSE;
+ Decoder.VERBOSE = 0;
+ }
+
+ @AfterClass
+ public void tearDown() throws Exception {
+ this.decoder.cleanUp();
+ this.decoder = null;
+ Decoder.VERBOSE = previousLogLevel;
+ }
+
+
+
+ // This test was created specifically to reproduce a multithreaded issue
+ // related to mapped byte array access in the PackedGrammer getAlignmentArray
+ // function.
+
+ // We'll test the decoding engine using N = 10,000 identical inputs. This
+ // should be sufficient to induce concurrent data access for many shared
+ // data structures.
+
+ @Test()
+ public void givenPackedGrammar_whenNTranslationsCalledConcurrently_thenReturnNResults() throws IOException {
+ // GIVEN
+
+ int inputLines = 10000;
+ joshuaConfig.use_structured_output = true; // Enabled alignments.
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < inputLines; i++) {
+ sb.append(INPUT + "\n");
+ }
+
+ // Append a large string together to simulate N requests to the decoding
+ // engine.
+ TranslationRequestStream req = new TranslationRequestStream(
+ new BufferedReader(new InputStreamReader(new ByteArrayInputStream(sb.toString()
- .getBytes(Charset.forName("UTF-8"))))), joshuaConfig);
-
++ .getBytes(StandardCharsets.UTF_8)))), joshuaConfig);
++
+ ByteArrayOutputStream output = new ByteArrayOutputStream();
+
+ // WHEN
+ // Translate all segments in parallel.
+ TranslationResponseStream translationResponseStream = this.decoder.decodeAll(req);
+
+ ArrayList<Translation> translationResults = new ArrayList<Translation>();
+
+
+ final long translationStartTime = System.nanoTime();
+ try {
+ for (Translation t: translationResponseStream)
+ translationResults.add(t);
+ } finally {
+ if (output != null) {
+ try {
+ output.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ final long translationEndTime = System.nanoTime();
+ final double pipelineLoadDurationInSeconds = (translationEndTime - translationStartTime)
+ / ((double)NANO_SECONDS_PER_SECOND);
+ System.err.println(String.format("%.2f seconds", pipelineLoadDurationInSeconds));
+
+ // THEN
+ assertTrue(translationResults.size() == inputLines);
+ }
+
+ @Test(expectedExceptions = RuntimeException.class,
+ expectedExceptionsMessageRegExp = EXCEPTION_MESSAGE)
+ public void givenDecodeAllCalled_whenRuntimeExceptionThrown_thenPropagate() throws IOException {
+ // GIVEN
+ // A spy request stream that will cause an exception to be thrown on a threadpool thread
+ TranslationRequestStream spyReq = Mockito.spy(new TranslationRequestStream(null, joshuaConfig));
+ doReturn(createSentenceSpyWithRuntimeExceptions()).when(spyReq).next();
+
+ // WHEN
+ // Translate all segments in parallel.
+ TranslationResponseStream translationResponseStream = this.decoder.decodeAll(spyReq);
+
+ ArrayList<Translation> translationResults = new ArrayList<>();
+ for (Translation t: translationResponseStream)
+ translationResults.add(t);
+ }
+
+ private Sentence createSentenceSpyWithRuntimeExceptions() {
+ Sentence sent = new Sentence(INPUT, 0, joshuaConfig);
+ Sentence spy = Mockito.spy(sent);
+ Mockito.when(spy.target()).thenThrow(new RuntimeException(EXCEPTION_MESSAGE));
+ return spy;
+ }
+}
[03/17] incubator-joshua git commit: Fix a number of issues: - Reader
now implements AutoCloseable - Close various leaks from LineReader -
LineReader no longer implements custom finalize(). Resources should be
explicitly closed when no longer needed. T[...] (subject line truncated in the archive)
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java b/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
index f8173b8..0bda3e6 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
@@ -21,27 +21,27 @@ package org.apache.joshua.decoder.ff.tm.packed;
/***
* This package implements Joshua's packed grammar structure, which enables the efficient loading
* and accessing of grammars. It is described in the paper:
- *
+ *
* @article{ganitkevitch2012joshua,
* Author = {Ganitkevitch, J. and Cao, Y. and Weese, J. and Post, M. and Callison-Burch, C.},
* Journal = {Proceedings of WMT12},
* Title = {Joshua 4.0: Packing, PRO, and paraphrases},
* Year = {2012}}
- *
+ *
* The packed grammar works by compiling out the grammar tries into a compact format that is loaded
* and parsed directly from Java arrays. A fundamental problem is that Java arrays are indexed
* by ints and not longs, meaning the maximum size of the packed grammar is about 2 GB. This forces
* the use of packed grammar slices, which together constitute the grammar. The figure in the
- * paper above shows what each slice looks like.
- *
+ * paper above shows what each slice looks like.
+ *
* The division across slices is done in a depth-first manner. Consider the entire grammar organized
* into a single source-side trie. The splits across tries are done by grouping the root-level
- * outgoing trie arcs --- and the entire trie beneath them --- across slices.
- *
- * This presents a problem: if the subtree rooted beneath a single top-level arc is too big for a
+ * outgoing trie arcs --- and the entire trie beneath them --- across slices.
+ *
+ * This presents a problem: if the subtree rooted beneath a single top-level arc is too big for a
* slice, the grammar can't be packed. This happens with very large Hiero grammars, for example,
* where there are a *lot* of rules that start with [X].
- *
+ *
* A solution being worked on is to split that symbol and pack them into separate grammars with a
* shared vocabulary, and then rely on Joshua's ability to query multiple grammars for rules to
* solve this problem. This is not currently implemented but could be done directly in the
@@ -63,7 +63,6 @@ import java.io.InputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
-import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Files;
@@ -73,7 +72,6 @@ import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@@ -94,15 +92,14 @@ import org.apache.joshua.util.FormatUtils;
import org.apache.joshua.util.encoding.EncoderConfiguration;
import org.apache.joshua.util.encoding.FloatEncoder;
import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
public class PackedGrammar extends AbstractGrammar {
private static final Logger LOG = LoggerFactory.getLogger(PackedGrammar.class);
@@ -135,14 +132,14 @@ public class PackedGrammar extends AbstractGrammar {
if (!Vocabulary.read(vocabFile)) {
throw new RuntimeException("mismatches or collisions while reading on-disk vocabulary");
}
-
+
// Read the config
String configFile = grammar_dir + File.separator + "config";
if (new File(configFile).exists()) {
LOG.info("Reading packed config: {}", configFile);
readConfig(configFile);
}
-
+
// Read the quantizer setup.
LOG.info("Reading encoder configuration: {}{}encoding", grammar_dir, File.separator);
encoding = new EncoderConfiguration();
@@ -226,14 +223,14 @@ public class PackedGrammar extends AbstractGrammar {
* that represent the subtrie for a particular firstword.
* If the GrammarPacker has to distribute rules for a
* source-side firstword across multiple slices, a
- * SliceAggregatingTrie node is created that aggregates those
+ * SliceAggregatingTrie node is created that aggregates those
* tries to hide
* this additional complexity from the grammar interface
* This feature allows packing of grammars where the list of rules
* for a single source-side firstword would exceed the maximum array
* size of Java (2gb).
*/
- public final class PackedRoot implements Trie {
+ public static final class PackedRoot implements Trie {
private final HashMap<Integer, Trie> lookup;
@@ -241,9 +238,9 @@ public class PackedGrammar extends AbstractGrammar {
final Map<Integer, List<Trie>> childTries = collectChildTries(slices);
lookup = buildLookupTable(childTries);
}
-
+
/**
- * Determines whether trie nodes for source first-words are spread over
+ * Determines whether trie nodes for source first-words are spread over
* multiple packedSlices by counting their occurrences.
* @param slices
* @return A mapping from first word ids to a list of trie nodes.
@@ -251,12 +248,12 @@ public class PackedGrammar extends AbstractGrammar {
private Map<Integer, List<Trie>> collectChildTries(final List<PackedSlice> slices) {
final Map<Integer, List<Trie>> childTries = new HashMap<>();
for (PackedSlice packedSlice : slices) {
-
+
// number of tries stored in this packedSlice
final int num_children = packedSlice.source[0];
for (int i = 0; i < num_children; i++) {
final int id = packedSlice.source[2 * i + 1];
-
+
/* aggregate tries with same root id
* obtain a Trie node, already at the correct address in the packedSlice.
* In other words, the lookup index already points to the correct trie node in the packedSlice.
@@ -271,7 +268,7 @@ public class PackedGrammar extends AbstractGrammar {
}
return childTries;
}
-
+
/**
* Build a lookup table for children tries.
* If the list contains only a single child node, a regular trie node
@@ -497,7 +494,7 @@ public class PackedGrammar extends AbstractGrammar {
}
featurePosition += EncoderConfiguration.ID_SIZE + encoder.size();
}
-
+
return featureVector;
}
@@ -511,7 +508,7 @@ public class PackedGrammar extends AbstractGrammar {
if (alignments == null)
throw new RuntimeException("No alignments available.");
int alignment_position = getIntFromByteBuffer(block_id, alignments);
- int num_points = (int) alignments.get(alignment_position);
+ int num_points = alignments.get(alignment_position);
byte[] alignment = new byte[num_points * 2];
alignments.position(alignment_position + 1);
@@ -533,6 +530,7 @@ public class PackedGrammar extends AbstractGrammar {
return getTrie(0);
}
+ @Override
public String toString() {
return name;
}
@@ -540,9 +538,9 @@ public class PackedGrammar extends AbstractGrammar {
/**
* A trie node within the grammar slice. Identified by its position within the source array,
* and, as a supplement, the source string leading from the trie root to the node.
- *
+ *
* @author jg
- *
+ *
*/
public class PackedTrie implements Trie, RuleCollection {
@@ -785,14 +783,14 @@ public class PackedGrammar extends AbstractGrammar {
throw new UnsupportedOperationException();
}
}
-
+
/**
* A packed phrase pair represents a rule of the form of a phrase pair, packed with the
* grammar-packer.pl script, which simply adds a nonterminal [X] to the left-hand side of
* all phrase pairs (and converts the Moses features). The packer then packs these. We have
* to then put a nonterminal on the source and target sides to treat the phrase pairs like
- * left-branching rules, which is how Joshua deals with phrase decoding.
- *
+ * left-branching rules, which is how Joshua deals with phrase decoding.
+ *
* @author Matt Post post@cs.jhu.edu
*
*/
@@ -845,17 +843,17 @@ public class PackedGrammar extends AbstractGrammar {
/**
* Take the English phrase of the underlying rule and prepend an [X].
- *
+ *
* @return the augmented phrase
*/
@Override
public int[] getEnglish() {
return this.englishSupplier.get();
}
-
+
/**
* Take the French phrase of the underlying rule and prepend an [X].
- *
+ *
* @return the augmented French phrase
*/
@Override
@@ -866,10 +864,10 @@ public class PackedGrammar extends AbstractGrammar {
System.arraycopy(src, 0, phrase, 1, src.length);
return phrase;
}
-
+
/**
* Similarly the alignment array needs to be shifted over by one.
- *
+ *
* @return the byte[] alignment
*/
@Override
@@ -967,12 +965,12 @@ public class PackedGrammar extends AbstractGrammar {
public FeatureVector getFeatureVector() {
return this.featureVectorSupplier.get();
}
-
+
@Override
public byte[] getAlignment() {
return this.alignmentsSupplier.get();
}
-
+
@Override
public String getAlignmentString() {
throw new RuntimeException("AlignmentString not implemented for PackedRule!");
@@ -1018,23 +1016,23 @@ public class PackedGrammar extends AbstractGrammar {
public void addOOVRules(int word, List<FeatureFunction> featureFunctions) {
throw new RuntimeException("PackedGrammar.addOOVRules(): I can't add OOV rules");
}
-
+
@Override
public void addRule(Rule rule) {
throw new RuntimeException("PackedGrammar.addRule(): I can't add rules");
}
-
- /**
+
+ /**
* Read the config file
- *
+ *
* TODO: this should be rewritten using typeconfig.
- *
+ *
* @param config
* @throws IOException
*/
private void readConfig(String config) throws IOException {
int version = 0;
-
+
for (String line: new LineReader(config)) {
String[] tokens = line.split(" = ");
if (tokens[0].equals("max-source-len"))
@@ -1043,7 +1041,7 @@ public class PackedGrammar extends AbstractGrammar {
version = Integer.parseInt(tokens[1]);
}
}
-
+
if (version != 3) {
String message = String.format("The grammar at %s was packed with packer version %d, but the earliest supported version is %d",
this.grammarDir, version, SUPPORTED_VERSION);
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/lattice/Lattice.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/lattice/Lattice.java b/src/main/java/org/apache/joshua/lattice/Lattice.java
index 2332159..c557c07 100644
--- a/src/main/java/org/apache/joshua/lattice/Lattice.java
+++ b/src/main/java/org/apache/joshua/lattice/Lattice.java
@@ -24,7 +24,6 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Stack;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -361,6 +360,7 @@ public class Lattice<Value> implements Iterable<Node<Value>> {
*
* @return An iterator over the nodes in this lattice.
*/
+ @Override
public Iterator<Node<Value>> iterator() {
return nodes.iterator();
}
@@ -471,56 +471,10 @@ public class Lattice<Value> implements Iterable<Node<Value>> {
}
/**
- * Topologically sorts the nodes and reassigns their numbers. Assumes that the first node is the
- * source, but otherwise assumes nothing about the input.
- *
- * Probably correct, but untested.
- */
- @SuppressWarnings("unused")
- private void topologicalSort() {
- HashMap<Node<Value>, List<Arc<Value>>> outgraph = new HashMap<Node<Value>, List<Arc<Value>>>();
- HashMap<Node<Value>, List<Arc<Value>>> ingraph = new HashMap<Node<Value>, List<Arc<Value>>>();
- for (Node<Value> node: nodes) {
- ArrayList<Arc<Value>> arcs = new ArrayList<Arc<Value>>();
- for (Arc<Value> arc: node.getOutgoingArcs()) {
- arcs.add(arc);
-
- if (! ingraph.containsKey(arc.getHead()))
- ingraph.put(arc.getHead(), new ArrayList<Arc<Value>>());
- ingraph.get(arc.getHead()).add(arc);
-
- outgraph.put(node, arcs);
- }
- }
-
- ArrayList<Node<Value>> sortedNodes = new ArrayList<Node<Value>>();
- Stack<Node<Value>> stack = new Stack<Node<Value>>();
- stack.push(nodes.get(0));
-
- while (! stack.empty()) {
- Node<Value> node = stack.pop();
- sortedNodes.add(node);
- for (Arc<Value> arc: outgraph.get(node)) {
- outgraph.get(node).remove(arc);
- ingraph.get(arc.getHead()).remove(arc);
-
- if (ingraph.get(arc.getHead()).size() == 0)
- sortedNodes.add(arc.getHead());
- }
- }
-
- int id = 0;
- for (Node<Value> node : sortedNodes)
- node.setID(id++);
-
- this.nodes = sortedNodes;
- }
-
- /**
- * Constructs a lattice from a given string representation.
+ * Constructs a lattice from a given string representation.
*
- * @param data String representation of a lattice.
- * @return A lattice that corresponds to the given string.
+ * @param data String representation of a lattice.
+ * @return A lattice that corresponds to the given string.
*/
public static Lattice<String> createFromString(String data) {
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/metrics/CHRF.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/metrics/CHRF.java b/src/main/java/org/apache/joshua/metrics/CHRF.java
index d67f6e0..dcf606a 100644
--- a/src/main/java/org/apache/joshua/metrics/CHRF.java
+++ b/src/main/java/org/apache/joshua/metrics/CHRF.java
@@ -27,18 +27,18 @@ import java.util.logging.Logger;
* - Adapted to extend Joshua's EvaluationMetric class
* - Use of a length penalty to prevent chrF to prefer too long (with beta %gt; 1) or too short (with beta < 1) translations
* - Use of hash tables for efficient n-gram matching
- *
+ *
* The metric has 2 parameters:
* - Beta. It assigns beta times more weight to recall than to precision. By default 1.
* Although for evaluation the best correlation was found with beta=3, we've found the
* best results for tuning so far with beta=1
* - Max-ngram. Maximum n-gram length (characters). By default 6.
- *
+ *
* If you use this metric in your research please cite [2].
- *
+ *
* [1] Maja Popovic. 2015. chrF: character n-gram F-score for automatic MT evaluation.
* In Proceedings of the Tenth Workshop on Statistical Machine Translation. Lisbon, Portugal, pages 392\u2013395.
- * [2] Víctor Sánchez Cartagena and Antonio Toral. 2016.
+ * [2] Víctor Sánchez Cartagena and Antonio Toral. 2016.
* Abu-MaTran at WMT 2016 Translation Task: Deep Learning, Morphological Segmentation and Tuning on Character Sequences.
* In Proceedings of the First Conference on Machine Translation (WMT16). Berlin, Germany.
@@ -51,17 +51,17 @@ public class CHRF extends EvaluationMetric {
protected double factor;
protected int maxGramLength = 6; // The maximum n-gram we care about
//private double[] nGramWeights; //TODO to weight them differently
-
+
//private String metricName;
//private boolean toBeMinimized;
//private int suffStatsCount;
-
+
public CHRF()
{
this(1, 6);
}
-
+
public CHRF(String[] CHRF_options)
{
//
@@ -71,44 +71,47 @@ public class CHRF extends EvaluationMetric {
//
this(Double.parseDouble(CHRF_options[0]), Integer.parseInt(CHRF_options[1]));
}
-
- public CHRF(double bt, int mxGrmLn){
+
+ public CHRF(double bt, int mxGrmLn){
if (bt > 0) {
beta = bt;
} else {
logger.severe("Beta must be positive");
System.exit(1);
}
-
+
if (mxGrmLn >= 1) {
maxGramLength = mxGrmLn;
} else {
logger.severe("Maximum gram length must be positive");
System.exit(1);
}
-
+
initialize(); // set the data members of the metric
}
+ @Override
protected void initialize()
{
metricName = "CHRF";
toBeMinimized = false;
suffStatsCount = 4 * maxGramLength;
- factor = Math.pow(beta, 2);
+ factor = Math.pow(beta, 2);
}
-
+
+ @Override
public double bestPossibleScore() { return 100.0; }
-
+
+ @Override
public double worstPossibleScore() { return 0.0; }
protected String separateCharacters(String s)
{
- String s_chars = "";
+ String s_chars = "";
//alternative implementation (the one below seems more robust)
/*for (int i = 0; i < s.length(); i++) {
if (s.charAt(i) == ' ') continue;
- s_chars += s.charAt(i) + " ";
+ s_chars += s.charAt(i) + " ";
}
System.out.println("CHRF separate chars1: " + s_chars);*/
@@ -122,7 +125,7 @@ public class CHRF extends EvaluationMetric {
return s_chars;
}
-
+
protected HashMap<String, Integer>[] getGrams(String s)
{
HashMap<String, Integer>[] grams = new HashMap[1 + maxGramLength];
@@ -131,7 +134,7 @@ public class CHRF extends EvaluationMetric {
grams[n] = new HashMap<String, Integer>();
}
-
+
for (int n=1; n<=maxGramLength; n++){
String gram = "";
for (int i = 0; i < s.length() - n + 1; i++){
@@ -143,7 +146,7 @@ public class CHRF extends EvaluationMetric {
grams[n].put(gram, 1);
}
}
-
+
}
/* debugging
@@ -153,25 +156,22 @@ public class CHRF extends EvaluationMetric {
for (String gram: grams[n].keySet()){
key = gram.toString();
value = grams[n].get(gram).toString();
- System.out.println(key + " " + value);
+ System.out.println(key + " " + value);
}
}*/
-
+
return grams;
}
-
+
protected int[] candRefErrors(HashMap<String, Integer> ref, HashMap<String, Integer> cand)
{
int[] to_return = {0,0};
String gram;
- int cand_grams = 0, ref_grams = 0;
+ int cand_grams = 0;
int candGramCount = 0, refGramCount = 0;
int errors = 0;
- double result = 0;
- String not_found = "";
-
-
+
Iterator<String> it = (cand.keySet()).iterator();
while (it.hasNext()) {
@@ -180,41 +180,36 @@ public class CHRF extends EvaluationMetric {
cand_grams += candGramCount;
if (ref.containsKey(gram)) {
refGramCount = ref.get(gram);
- ref_grams += refGramCount;
if(candGramCount>refGramCount){
int error_here = candGramCount - refGramCount;
errors += error_here;
- not_found += gram + " (" + error_here + " times) ";
}
} else {
refGramCount = 0;
errors += candGramCount;
- not_found += gram + " ";
- }
+ }
}
-
+
//System.out.println(" Ngrams not found: " + not_found);
-
+
to_return[0] = cand_grams;
to_return[1] = errors;
-
+
return to_return;
}
-
+
+ @Override
public int[] suffStats(String cand_str, int i) //throws Exception
{
int[] stats = new int[suffStatsCount];
- double[] precisions = new double[maxGramLength+1];
- double[] recalls = new double[maxGramLength+1];
-
//TODO check unicode chars correctly split
String cand_char = separateCharacters(cand_str);
String ref_char = separateCharacters(refSentences[i][0]);
-
+
HashMap<String, Integer>[] grams_cand = getGrams(cand_char);
HashMap<String, Integer>[] grams_ref = getGrams(ref_char);
-
+
for (int n = 1; n <= maxGramLength; ++n) {
//System.out.println("Calculating precision...");
int[] precision_vals = candRefErrors(grams_ref[n], grams_cand[n]);
@@ -222,7 +217,7 @@ public class CHRF extends EvaluationMetric {
//System.out.println("Calculating recall...");
int[] recall_vals = candRefErrors(grams_cand[n], grams_ref[n]);
//System.out.println(" length: " + recall_vals[0] + ", errors: " + recall_vals[1]);
-
+
stats[4*(n-1)] = precision_vals[0]; //cand_grams
stats[4*(n-1)+1] = precision_vals[1]; //errors (precision)
stats[4*(n-1)+2] = recall_vals[0]; //ref_grams
@@ -233,6 +228,7 @@ public class CHRF extends EvaluationMetric {
}
+ @Override
public double score(int[] stats)
{
int precision_ngrams, recall_ngrams, precision_errors, recall_errors;
@@ -240,9 +236,9 @@ public class CHRF extends EvaluationMetric {
double[] recalls = new double[maxGramLength+1];
double[] fs = new double[maxGramLength+1];
//double[] scs = new double[maxGramLength+1];
- double totalPrecision = 0, totalRecall = 0, totalF = 0, totalSC = 0;
+ double totalF = 0, totalSC = 0;
double lp = 1;
-
+
if (stats.length != suffStatsCount) {
System.out.println("Mismatch between stats.length and suffStatsCount (" + stats.length + " vs. " + suffStatsCount + ") in NewMetric.score(int[])");
System.exit(1);
@@ -257,42 +253,41 @@ public class CHRF extends EvaluationMetric {
if (precision_ngrams != 0)
precisions[n] = 100 - 100*precision_errors/ (double)precision_ngrams;
else precisions[n] = 0;
-
+
if (recall_ngrams != 0)
recalls[n] = 100 - 100*recall_errors/ (double)recall_ngrams;
else
recalls[n] = 0;
-
+
if(precisions[n] != 0 || recalls[n] != 0)
fs[n] = (1+factor) * recalls[n] * precisions[n] / (factor * precisions[n] + recalls[n]);
else
fs[n] = 0;
-
+
//System.out.println("Precision (n=" + n + "): " + precisions[n]);
//System.out.println("Recall (n=" + n + "): " + recalls[n]);
//System.out.println("F (n=" + n + "): " + fs[n]);
- totalPrecision += (1/(double)maxGramLength) * precisions[n];
- totalRecall += (1/(double)maxGramLength) * recalls[n];
totalF += (1/(double)maxGramLength) * fs[n];
}
//length penalty
- if (beta>1){ //penalise long translations
+ if (beta>1){ //penalise long translations
lp = Math.min(1, stats[2]/(double)stats[0]);
} else if (beta < 1){ //penalise short translations
lp = Math.min(1, stats[0]/(double)stats[2]);
}
totalSC = totalF*lp;
-
+
//System.out.println("Precision (total): " + totalPrecision);
//System.out.println("Recall (total):" + totalRecall);
//System.out.println("F (total): " + totalF);
-
+
return totalSC;
}
+ @Override
public void printDetailedScore_fromStats(int[] stats, boolean oneLiner)
{
System.out.println(metricName + " = " + score(stats));
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/metrics/SARI.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/metrics/SARI.java b/src/main/java/org/apache/joshua/metrics/SARI.java
index 129e4af..9ee3af3 100644
--- a/src/main/java/org/apache/joshua/metrics/SARI.java
+++ b/src/main/java/org/apache/joshua/metrics/SARI.java
@@ -14,30 +14,30 @@
*/
package org.apache.joshua.metrics;
-// Changed PROCore.java (text normalization function) and EvaluationMetric too
-
-import java.util.Map;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.logging.Logger;
-
import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStreamReader;
import java.io.File;
import java.io.FileInputStream;
+import java.io.IOException;
import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.HashMap;
+import java.util.Iterator;
+
+// Changed PROCore.java (text normalization function) and EvaluationMetric too
+
+import java.util.Map;
+import java.util.logging.Logger;
/***
* Implementation of the SARI metric for text-to-text correction.
- *
+ *
* \@article{xu2016optimizing,
* title={Optimizing statistical machine translation for text simplification},
* author={Xu, Wei and Napoles, Courtney and Pavlick, Ellie and Chen, Quanze and Callison-Burch, Chris},
* journal={Transactions of the Association for Computational Linguistics},
* volume={4},
* year={2016}}
- *
+ *
* @author Wei Xu
*/
public class SARI extends EvaluationMetric {
@@ -77,6 +77,7 @@ public class SARI extends EvaluationMetric {
}
+ @Override
protected void initialize() {
metricName = "SARI";
toBeMinimized = false;
@@ -88,10 +89,12 @@ public class SARI extends EvaluationMetric {
}
+ @Override
public double bestPossibleScore() {
return 1.0;
}
+ @Override
public double worstPossibleScore() {
return 0.0;
}
@@ -165,6 +168,7 @@ public class SARI extends EvaluationMetric {
}
// set contents of stats[] here!
+ @Override
public int[] suffStats(String cand_str, int i) {
int[] stats = new int[suffStatsCount];
@@ -173,9 +177,9 @@ public class SARI extends EvaluationMetric {
for (int n = 1; n <= maxGramLength; ++n) {
// ADD OPERATIONS
- HashMap cand_sub_src = substractHashMap(candNgramCounts[n], srcNgramCounts[i][n]);
- HashMap cand_and_ref_sub_src = intersectHashMap(cand_sub_src, refNgramCounts[i][n]);
- HashMap ref_sub_src = substractHashMap(refNgramCounts[i][n], srcNgramCounts[i][n]);
+ HashMap<String, Integer> cand_sub_src = substractHashMap(candNgramCounts[n], srcNgramCounts[i][n]);
+ HashMap<String, Integer> cand_and_ref_sub_src = intersectHashMap(cand_sub_src, refNgramCounts[i][n]);
+ HashMap<String, Integer> ref_sub_src = substractHashMap(refNgramCounts[i][n], srcNgramCounts[i][n]);
stats[StatIndex.values().length * (n - 1)
+ StatIndex.ADDBOTH.ordinal()] = cand_and_ref_sub_src.keySet().size();
@@ -190,11 +194,11 @@ public class SARI extends EvaluationMetric {
// System.out.println("ref_sub_src" + ref_sub_src + ref_sub_src.keySet().size());
// DELETION OPERATIONS
- HashMap src_sub_cand = substractHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_sub_ref = substractHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_sub_cand_sub_ref = intersectHashMap(src_sub_cand, src_sub_ref, 1, 1);
+ HashMap<String, Integer> src_sub_cand = substractHashMap(srcNgramCounts[i][n], candNgramCounts[n],
+ refsPerSen, refsPerSen);
+ HashMap<String, Integer> src_sub_ref = substractHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
+ refsPerSen, 1);
+ HashMap<String, Integer> src_sub_cand_sub_ref = intersectHashMap(src_sub_cand, src_sub_ref, 1, 1);
stats[StatIndex.values().length * (n - 1) + StatIndex.DELBOTH.ordinal()] = sumHashMapByValues(
src_sub_cand_sub_ref);
@@ -209,14 +213,14 @@ public class SARI extends EvaluationMetric {
// System.out.println("src_sub_ref" + src_sub_ref + sumHashMapByValues(src_sub_ref));
stats[StatIndex.values().length * (n - 1) + StatIndex.DELREF.ordinal()] = src_sub_ref.keySet()
- .size() * this.refsPerSen;
+ .size() * refsPerSen;
// KEEP OPERATIONS
- HashMap src_and_cand = intersectHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_and_ref = intersectHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_and_cand_and_ref = intersectHashMap(src_and_cand, src_and_ref, 1, 1);
+ HashMap<String, Integer> src_and_cand = intersectHashMap(srcNgramCounts[i][n], candNgramCounts[n],
+ refsPerSen, refsPerSen);
+ HashMap<String, Integer> src_and_ref = intersectHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
+ refsPerSen, 1);
+ HashMap<String, Integer> src_and_cand_and_ref = intersectHashMap(src_and_cand, src_and_ref, 1, 1);
stats[StatIndex.values().length * (n - 1)
+ StatIndex.KEEPBOTH.ordinal()] = sumHashMapByValues(src_and_cand_and_ref);
@@ -232,55 +236,11 @@ public class SARI extends EvaluationMetric {
divideHashMap(src_and_cand_and_ref, src_and_ref));
stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPREF.ordinal()] = src_and_ref
.keySet().size();
-
- // System.out.println("src_and_cand_and_ref" + src_and_cand_and_ref);
- // System.out.println("src_and_cand" + src_and_cand);
- // System.out.println("src_and_ref" + src_and_ref);
-
- // stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPBOTH2.ordinal()] = (int)
- // sumHashMapByDoubleValues(divideHashMap(src_and_cand_and_ref,src_and_ref)) * 100000000 /
- // src_and_ref.keySet().size() ;
- // stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPREF.ordinal()] =
- // src_and_ref.keySet().size() * 8;
-
- // System.out.println("src_and_cand_and_ref" + src_and_cand_and_ref);
- // System.out.println("src_and_cand" + src_and_cand);
- // System.out.println("divide" + divideHashMap(src_and_cand_and_ref,src_and_cand));
- // System.out.println(sumHashMapByDoubleValues(divideHashMap(src_and_cand_and_ref,src_and_cand)));
-
}
-
- int n = 1;
-
- // System.out.println("CAND: " + candNgramCounts[n]);
- // System.out.println("SRC: " + srcNgramCounts[i][n]);
- // System.out.println("REF: " + refNgramCounts[i][n]);
-
- HashMap src_and_cand = intersectHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_and_ref = intersectHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_and_cand_and_ref = intersectHashMap(src_and_cand, src_and_ref, 1, 1);
- // System.out.println("SRC&CAND&REF : " + src_and_cand_and_ref);
-
- HashMap cand_sub_src = substractHashMap(candNgramCounts[n], srcNgramCounts[i][n]);
- HashMap cand_and_ref_sub_src = intersectHashMap(cand_sub_src, refNgramCounts[i][n]);
- // System.out.println("CAND&REF-SRC : " + cand_and_ref_sub_src);
-
- HashMap src_sub_cand = substractHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_sub_ref = substractHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_sub_cand_sub_ref = intersectHashMap(src_sub_cand, src_sub_ref, 1, 1);
- // System.out.println("SRC-REF-CAND : " + src_sub_cand_sub_ref);
-
- // System.out.println("DEBUG:" + Arrays.toString(stats));
- // System.out.println("REF-SRC: " + substractHashMap(refNgramCounts[i], srcNgramCounts[i][0],
- // (double)refsPerSen));
-
return stats;
}
+ @Override
public double score(int[] stats) {
if (stats.length != suffStatsCount) {
System.out.println("Mismatch between stats.length and suffStatsCount (" + stats.length
@@ -320,23 +280,13 @@ public class SARI extends EvaluationMetric {
+ StatIndex.DELBOTH.ordinal()];
int delCandTotalNgram = stats[StatIndex.values().length * (n - 1)
+ StatIndex.DELCAND.ordinal()];
- int delRefTotalNgram = stats[StatIndex.values().length * (n - 1)
- + StatIndex.DELREF.ordinal()];
double prec_del_n = 0.0;
if (delCandTotalNgram > 0) {
prec_del_n = delCandCorrectNgram / (double) delCandTotalNgram;
}
- double recall_del_n = 0.0;
- if (delRefTotalNgram > 0) {
- recall_del_n = delCandCorrectNgram / (double) delRefTotalNgram;
- }
-
// System.out.println("\nDEBUG-SARI:" + delCandCorrectNgram + " " + delRefTotalNgram);
-
- double f1_del_n = meanHarmonic(prec_del_n, recall_del_n);
-
// sc += weights[n] * f1_del_n;
sc += weights[n] * prec_del_n;
@@ -410,7 +360,7 @@ public class SARI extends EvaluationMetric {
double sumcounts = 0;
for (Map.Entry<String, Double> e : counter.entrySet()) {
- sumcounts += (double) e.getValue();
+ sumcounts += e.getValue();
}
return sumcounts;
@@ -420,7 +370,7 @@ public class SARI extends EvaluationMetric {
int sumcounts = 0;
for (Map.Entry<String, Integer> e : counter.entrySet()) {
- sumcounts += (int) e.getValue();
+ sumcounts += e.getValue();
}
return sumcounts;
@@ -432,7 +382,6 @@ public class SARI extends EvaluationMetric {
for (Map.Entry<String, Integer> e : counter1.entrySet()) {
String ngram = e.getKey();
- int count1 = e.getValue();
int count2 = counter2.containsKey(ngram) ? counter2.get(ngram) : 0;
if (count2 == 0) {
newcounter.put(ngram, 1);
@@ -482,7 +431,6 @@ public class SARI extends EvaluationMetric {
for (Map.Entry<String, Integer> e : counter1.entrySet()) {
String ngram = e.getKey();
- int count1 = e.getValue();
int count2 = counter2.containsKey(ngram) ? counter2.get(ngram) : 0;
if (count2 > 0) {
newcounter.put(ngram, 1);
@@ -661,6 +609,7 @@ public class SARI extends EvaluationMetric {
}
+ @Override
public void printDetailedScore_fromStats(int[] stats, boolean oneLiner) {
System.out.println(metricName + " = " + score(stats));
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/mira/MIRACore.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/mira/MIRACore.java b/src/main/java/org/apache/joshua/mira/MIRACore.java
index 42dd995..a4a6b84 100755
--- a/src/main/java/org/apache/joshua/mira/MIRACore.java
+++ b/src/main/java/org/apache/joshua/mira/MIRACore.java
@@ -44,11 +44,12 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
+import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.Decoder;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.metrics.EvaluationMetric;
import org.apache.joshua.util.StreamGobbler;
-import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -64,19 +65,15 @@ public class MIRACore {
private TreeSet<Integer>[] indicesOfInterest_all;
private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
- private final Runtime myRuntime = Runtime.getRuntime();
private final static double NegInf = (-1.0 / 0.0);
private final static double PosInf = (+1.0 / 0.0);
private final static double epsilon = 1.0 / 1000000;
- private int progress;
-
private int verbosity; // anything of priority <= verbosity will be printed
// (lower value for priority means more important)
private Random randGen;
- private int generatedRands;
private int numSentences;
// number of sentences in the dev set
@@ -235,7 +232,7 @@ public class MIRACore {
private boolean usePseudoBleu = true; // need to use pseudo corpus to compute bleu?
private boolean returnBest = false; // return the best weight during tuning
private boolean needScale = true; // need scaling?
- private String trainingMode;
+
private int oraSelectMode = 1;
private int predSelectMode = 1;
private int miraIter = 1;
@@ -263,28 +260,27 @@ public class MIRACore {
this.joshuaConfiguration = joshuaConfiguration;
}
- public MIRACore(String[] args, JoshuaConfiguration joshuaConfiguration) {
+ public MIRACore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(args);
initialize(0);
}
- public MIRACore(String configFileName, JoshuaConfiguration joshuaConfiguration) {
+ public MIRACore(String configFileName, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(cfgFileToArgsArray(configFileName));
initialize(0);
}
- private void initialize(int randsToSkip) {
+ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
randGen = new Random(seed);
for (int r = 1; r <= randsToSkip; ++r) {
randGen.nextDouble();
}
- generatedRands = randsToSkip;
if (randsToSkip == 0) {
println("----------------------------------------------------", 1);
@@ -298,7 +294,7 @@ public class MIRACore {
// count the total num of sentences to be decoded, reffilename is the combined reference file
// name(auto generated)
- numSentences = countLines(refFileName) / refsPerSen;
+ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines() / refsPerSen;
// ??
processDocInfo();
@@ -311,7 +307,7 @@ public class MIRACore {
set_docSubsetInfo(docSubsetInfo);
// count the number of initial features
- numParams = countNonEmptyLines(paramsFileName) - 1;
+ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
numParamsOld = numParams;
// read parameter config file
@@ -862,7 +858,6 @@ public class MIRACore {
// iterations if the user specifies a value for prevMERTIterations
// that causes MERT to skip candidates from early iterations.
- double[] currFeatVal = new double[1 + numParams];
String[] featVal_str;
int totalCandidateCount = 0;
@@ -1105,7 +1100,6 @@ public class MIRACore {
for (String featurePair : featVal_str) {
String[] pair = featurePair.split("=");
String name = pair[0];
- Double value = Double.parseDouble(pair[1]);
int featId = Vocabulary.id(name);
// need to identify newly fired feats here
@@ -1529,7 +1523,7 @@ public class MIRACore {
/*
* line format:
- *
+ *
* i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val
* .*
*/
@@ -1599,8 +1593,6 @@ public class MIRACore {
BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
PrintWriter outFile = new PrintWriter(cfgFileName);
- BufferedReader inFeatDefFile = null;
- PrintWriter outFeatDefFile = null;
int origFeatNum = 0; // feat num in the template file
String line = inFile.readLine();
@@ -1803,7 +1795,7 @@ public class MIRACore {
// belongs to,
// and its order in that document. (can also use '-' instead of '_')
- int docInfoSize = countNonEmptyLines(docInfoFileName);
+ int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
if (docInfoSize < numSentences) { // format #1 or #2
numDocuments = docInfoSize;
@@ -1913,13 +1905,13 @@ public class MIRACore {
/*
* InputStream inStream = new FileInputStream(new File(origFileName)); BufferedReader inFile =
* new BufferedReader(new InputStreamReader(inStream, "utf8"));
- *
+ *
* FileOutputStream outStream = new FileOutputStream(newFileName, false); OutputStreamWriter
* outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new
* BufferedWriter(outStreamWriter);
- *
+ *
* String line; while(inFile.ready()) { line = inFile.readLine(); writeLine(line, outFile); }
- *
+ *
* inFile.close(); outFile.close();
*/
return true;
@@ -2017,10 +2009,10 @@ public class MIRACore {
/*
* OBSOLETE MODIFICATION //SPECIAL HANDLING FOR MIRA CLASSIFIER PARAMETERS String[] paramA
* = line.split("\\s+");
- *
+ *
* if( paramA[0].equals("-classifierParams") ) { String classifierParam = ""; for(int p=1;
* p<=paramA.length-1; p++) classifierParam += paramA[p]+" ";
- *
+ *
* if(paramA.length>=2) { String[] tmpParamA = new String[2]; tmpParamA[0] = paramA[0];
* tmpParamA[1] = classifierParam; paramA = tmpParamA; } else {
* println("Malformed line in config file:"); println(origLine); System.exit(70); } }//END
@@ -2559,12 +2551,12 @@ public class MIRACore {
/*
* 1: -docSet bottom 8d 2: -docSet bottom 25% the bottom ceil(0.20*numDocs) documents 3: -docSet
* top 8d 4: -docSet top 25% the top ceil(0.20*numDocs) documents
- *
+ *
* 5: -docSet window 11d around 90percentile 11 docs centered around 80th percentile (complain
* if not enough docs; don't adjust) 6: -docSet window 11d around 40rank 11 docs centered around
* doc ranked 50 (complain if not enough docs; don't adjust)
- *
- *
+ *
+ *
* [0]: method (0-6) [1]: first (1-indexed) [2]: last (1-indexed) [3]: size [4]: center [5]:
* arg1 (-1 for method 0) [6]: arg2 (-1 for methods 0-4)
*/
@@ -2723,10 +2715,10 @@ public class MIRACore {
} else {
nextIndex = 1;
}
- int lineCount = countLines(prefix + nextIndex);
+ int lineCount = new ExistingUTF8EncodedTextFile(prefix + nextIndex).getNumberOfLines();
for (int r = 0; r < numFiles; ++r) {
- if (countLines(prefix + nextIndex) != lineCount) {
+ if (new ExistingUTF8EncodedTextFile(prefix + nextIndex).getNumberOfLines() != lineCount) {
throw new RuntimeException("Line count mismatch in " + (prefix + nextIndex) + ".");
}
InputStream inStream = new FileInputStream(new File(prefix + nextIndex));
@@ -2887,109 +2879,11 @@ public class MIRACore {
return str;
}
- private int countLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null)
- ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
- private int countNonEmptyLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null && line.length() > 0)
- ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
private String fullPath(String dir, String fileName) {
File dummyFile = new File(dir, fileName);
return dummyFile.getAbsolutePath();
}
- @SuppressWarnings("unused")
- private void cleanupMemory() {
- cleanupMemory(100, false);
- }
-
- @SuppressWarnings("unused")
- private void cleanupMemorySilently() {
- cleanupMemory(100, true);
- }
-
- @SuppressWarnings("static-access")
- private void cleanupMemory(int reps, boolean silent) {
- int bytesPerMB = 1024 * 1024;
-
- long totalMemBefore = myRuntime.totalMemory();
- long freeMemBefore = myRuntime.freeMemory();
- long usedMemBefore = totalMemBefore - freeMemBefore;
-
- long usedCurr = usedMemBefore;
- long usedPrev = usedCurr;
-
- // perform garbage collection repeatedly, until there is no decrease in
- // the amount of used memory
- for (int i = 1; i <= reps; ++i) {
- myRuntime.runFinalization();
- myRuntime.gc();
- (Thread.currentThread()).yield();
-
- usedPrev = usedCurr;
- usedCurr = myRuntime.totalMemory() - myRuntime.freeMemory();
-
- if (usedCurr == usedPrev)
- break;
- }
-
- if (!silent) {
- long totalMemAfter = myRuntime.totalMemory();
- long freeMemAfter = myRuntime.freeMemory();
- long usedMemAfter = totalMemAfter - freeMemAfter;
-
- println("GC: d_used = " + ((usedMemAfter - usedMemBefore) / bytesPerMB) + " MB "
- + "(d_tot = " + ((totalMemAfter - totalMemBefore) / bytesPerMB) + " MB).", 2);
- }
- }
-
- @SuppressWarnings("unused")
- private void printMemoryUsage() {
- int bytesPerMB = 1024 * 1024;
- long totalMem = myRuntime.totalMemory();
- long freeMem = myRuntime.freeMemory();
- long usedMem = totalMem - freeMem;
-
- println("Allocated memory: " + (totalMem / bytesPerMB) + " MB " + "(of which "
- + (usedMem / bytesPerMB) + " MB is being used).", 2);
- }
-
private void println(Object obj, int priority) {
if (priority <= verbosity)
println(obj);
@@ -3008,20 +2902,12 @@ public class MIRACore {
System.out.print(obj);
}
- @SuppressWarnings("unused")
- private void showProgress() {
- ++progress;
- if (progress % 100000 == 0)
- print(".", 2);
- }
-
private ArrayList<Double> randomLambda() {
ArrayList<Double> retLambda = new ArrayList<Double>(1 + numParams);
for (int c = 1; c <= numParams; ++c) {
if (isOptimizable[c]) {
double randVal = randGen.nextDouble(); // number in [0.0,1.0]
- ++generatedRands;
randVal = randVal * (maxRandValue[c] - minRandValue[c]); // number in [0.0,max-min]
randVal = minRandValue[c] + randVal; // number in [min,max]
retLambda.set(c, randVal);
@@ -3032,81 +2918,4 @@ public class MIRACore {
return retLambda;
}
-
- private double[] randomPerturbation(double[] origLambda, int i, double method, double param,
- double mult) {
- double sigma = 0.0;
- if (method == 1) {
- sigma = 1.0 / Math.pow(i, param);
- } else if (method == 2) {
- sigma = Math.exp(-param * i);
- } else if (method == 3) {
- sigma = Math.max(0.0, 1.0 - (i / param));
- }
-
- sigma = mult * sigma;
-
- double[] retLambda = new double[1 + numParams];
-
- for (int c = 1; c <= numParams; ++c) {
- if (isOptimizable[c]) {
- double randVal = 2 * randGen.nextDouble() - 1.0; // number in [-1.0,1.0]
- ++generatedRands;
- randVal = randVal * sigma; // number in [-sigma,sigma]
- randVal = randVal * origLambda[c]; // number in [-sigma*orig[c],sigma*orig[c]]
- randVal = randVal + origLambda[c]; // number in
- // [orig[c]-sigma*orig[c],orig[c]+sigma*orig[c]]
- // = [orig[c]*(1-sigma),orig[c]*(1+sigma)]
- retLambda[c] = randVal;
- } else {
- retLambda[c] = origLambda[c];
- }
- }
-
- return retLambda;
- }
-
- @SuppressWarnings("unused")
- private HashSet<Integer> indicesToDiscard(double[] slope, double[] offset) {
- // some lines can be eliminated: the ones that have a lower offset
- // than some other line with the same slope.
- // That is, for any k1 and k2:
- // if slope[k1] = slope[k2] and offset[k1] > offset[k2],
- // then k2 can be eliminated.
- // (This is actually important to do as it eliminates a bug.)
- // print("discarding: ",4);
-
- int numCandidates = slope.length;
- HashSet<Integer> discardedIndices = new HashSet<Integer>();
- HashMap<Double, Integer> indicesOfSlopes = new HashMap<Double, Integer>();
- // maps slope to index of best candidate that has that slope.
- // ("best" as in the one with the highest offset)
-
- for (int k1 = 0; k1 < numCandidates; ++k1) {
- double currSlope = slope[k1];
- if (!indicesOfSlopes.containsKey(currSlope)) {
- indicesOfSlopes.put(currSlope, k1);
- } else {
- int existingIndex = indicesOfSlopes.get(currSlope);
- if (offset[existingIndex] > offset[k1]) {
- discardedIndices.add(k1);
- // print(k1 + " ",4);
- } else if (offset[k1] > offset[existingIndex]) {
- indicesOfSlopes.put(currSlope, k1);
- discardedIndices.add(existingIndex);
- // print(existingIndex + " ",4);
- }
- }
- }
-
- // old way of doing it; takes quadratic time (vs. linear time above)
- /*
- * for (int k1 = 0; k1 < numCandidates; ++k1) { for (int k2 = 0; k2 < numCandidates; ++k2) { if
- * (k1 != k2 && slope[k1] == slope[k2] && offset[k1] > offset[k2]) { discardedIndices.add(k2);
- * // print(k2 + " ",4); } } }
- */
-
- // println("",4);
- return discardedIndices;
- } // indicesToDiscard(double[] slope, double[] offset)
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/mira/Optimizer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/mira/Optimizer.java b/src/main/java/org/apache/joshua/mira/Optimizer.java
index 6eaced4..f51a5b3 100755
--- a/src/main/java/org/apache/joshua/mira/Optimizer.java
+++ b/src/main/java/org/apache/joshua/mira/Optimizer.java
@@ -18,9 +18,9 @@
*/
package org.apache.joshua.mira;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@@ -81,8 +81,6 @@ public class Optimizer {
String[] featInfo;
int thisBatchSize = 0;
int numBatch = 0;
- int numUpdate = 0;
- Iterator it;
Integer diffFeatId;
// update weights
@@ -98,7 +96,7 @@ public class Optimizer {
s = sents.get(sentCount);
// find out oracle and prediction
findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
-
+
// the model scores here are already scaled in findOraPred
oraMetric = oraPredScore[0];
oraScore = oraPredScore[1];
@@ -106,7 +104,7 @@ public class Optimizer {
predScore = oraPredScore[3];
oraFeat = oraPredFeat[0];
predFeat = oraPredFeat[1];
-
+
// update the scale
if (needScale) { // otherwise featscale remains 1.0
sumMetricScore += java.lang.Math.abs(oraMetric + predMetric);
@@ -119,7 +117,7 @@ public class Optimizer {
vecOraFeat = oraFeat.split("\\s+");
vecPredFeat = predFeat.split("\\s+");
-
+
//accumulate difference feature vector
if ( b == 0 ) {
for (int i = 0; i < vecOraFeat.length; i++) {
@@ -185,8 +183,8 @@ public class Optimizer {
if (!runPercep) { // otherwise eta=1.0
featNorm = 0;
Collection<Double> allDiff = featDiff.values();
- for (it = allDiff.iterator(); it.hasNext();) {
- diff = (Double) it.next();
+ for (Iterator<Double> it = allDiff.iterator(); it.hasNext();) {
+ diff = it.next();
featNorm += diff * diff / ( thisBatchSize * thisBatchSize );
}
}
@@ -199,10 +197,10 @@ public class Optimizer {
}
avgEta += eta;
Set<Integer> diffFeatSet = featDiff.keySet();
- it = diffFeatSet.iterator();
+ Iterator<Integer> it = diffFeatSet.iterator();
if ( java.lang.Math.abs(eta) > 1e-20 ) {
while (it.hasNext()) {
- diffFeatId = (Integer) it.next();
+ diffFeatId = it.next();
finalLambda[diffFeatId] =
finalLambda[diffFeatId] + eta * featDiff.get(diffFeatId) / thisBatchSize;
}
@@ -304,7 +302,7 @@ public class Optimizer {
// find out the 1-best candidate for each sentence
// this depends on the training mode
maxModelScore = NegInf;
- for (Iterator it = candSet.iterator(); it.hasNext();) {
+ for (Iterator<String> it = candSet.iterator(); it.hasNext();) {
modelScore = 0.0;
candStr = it.next().toString();
feat_str = feat_hash[i].get(candStr).split("\\s+");
@@ -363,7 +361,7 @@ public class Optimizer {
worstPredScore = PosInf;
}
- for (Iterator it = candSet.iterator(); it.hasNext();) {
+ for (Iterator<String> it = candSet.iterator(); it.hasNext();) {
cand = it.next().toString();
candMetric = computeSentMetric(sentId, cand); // compute metric score
@@ -605,11 +603,11 @@ public class Optimizer {
}
}
}
-
+
public double getMetricScore() {
return finalMetricScore;
}
-
+
private Vector<String> output;
private double[] initialLambda;
private double[] finalLambda;
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/pro/PROCore.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/pro/PROCore.java b/src/main/java/org/apache/joshua/pro/PROCore.java
index ec23e0a..5dc3311 100755
--- a/src/main/java/org/apache/joshua/pro/PROCore.java
+++ b/src/main/java/org/apache/joshua/pro/PROCore.java
@@ -31,6 +31,7 @@ import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
+import java.nio.charset.StandardCharsets;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Date;
@@ -49,7 +50,7 @@ import org.apache.joshua.decoder.Decoder;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.metrics.EvaluationMetric;
import org.apache.joshua.util.StreamGobbler;
-
+import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -65,19 +66,15 @@ public class PROCore {
private TreeSet<Integer>[] indicesOfInterest_all;
private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
- private final Runtime myRuntime = Runtime.getRuntime();
private final static double NegInf = (-1.0 / 0.0);
private final static double PosInf = (+1.0 / 0.0);
private final static double epsilon = 1.0 / 1000000;
- private int progress;
-
private int verbosity; // anything of priority <= verbosity will be printed
// (lower value for priority means more important)
private Random randGen;
- private int generatedRands;
private int numSentences;
// number of sentences in the dev set
@@ -256,28 +253,27 @@ public class PROCore {
this.joshuaConfiguration = joshuaConfiguration;
}
- public PROCore(String[] args, JoshuaConfiguration joshuaConfiguration) {
+ public PROCore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(args);
initialize(0);
}
- public PROCore(String configFileName, JoshuaConfiguration joshuaConfiguration) {
+ public PROCore(String configFileName, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(cfgFileToArgsArray(configFileName));
initialize(0);
}
- private void initialize(int randsToSkip) {
+ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
randGen = new Random(seed);
for (int r = 1; r <= randsToSkip; ++r) {
randGen.nextDouble();
}
- generatedRands = randsToSkip;
if (randsToSkip == 0) {
println("----------------------------------------------------", 1);
@@ -291,7 +287,7 @@ public class PROCore {
// COUNT THE TOTAL NUM OF SENTENCES TO BE DECODED, refFileName IS THE COMBINED REFERENCE FILE
// NAME(AUTO GENERATED)
- numSentences = countLines(refFileName) / refsPerSen;
+ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines() / refsPerSen;
// ??
processDocInfo();
@@ -304,7 +300,7 @@ public class PROCore {
set_docSubsetInfo(docSubsetInfo);
// count the number of initial features
- numParams = countNonEmptyLines(paramsFileName) - 1;
+ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
numParamsOld = numParams;
// read parameter config file
@@ -910,7 +906,6 @@ public class PROCore {
for (String featurePair : featVal_str) {
String[] pair = featurePair.split("=");
String name = pair[0];
- Double value = Double.parseDouble(pair[1]);
int featId = Vocabulary.id(name);
// need to identify newly fired feats here
if (featId > numParams) {
@@ -1573,8 +1568,6 @@ public class PROCore {
BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
PrintWriter outFile = new PrintWriter(cfgFileName);
- BufferedReader inFeatDefFile = null;
- PrintWriter outFeatDefFile = null;
int origFeatNum = 0; // feat num in the template file
String line = inFile.readLine();
@@ -1776,7 +1769,7 @@ public class PROCore {
// belongs to,
// and its order in that document. (can also use '-' instead of '_')
- int docInfoSize = countNonEmptyLines(docInfoFileName);
+ int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
if (docInfoSize < numSentences) { // format #1 or #2
numDocuments = docInfoSize;
@@ -1969,9 +1962,7 @@ public class PROCore {
Vector<String> argsVector = new Vector<String>();
- BufferedReader inFile = null;
- try {
- inFile = new BufferedReader(new FileReader(fileName));
+ try (BufferedReader inFile = new BufferedReader(new FileReader(fileName));) {
String line, origLine;
do {
line = inFile.readLine();
@@ -2044,8 +2035,6 @@ public class PROCore {
}
} while (line != null);
-
- inFile.close();
} catch (FileNotFoundException e) {
println("PRO configuration file " + fileName + " was not found!");
throw new RuntimeException(e);
@@ -2626,9 +2615,7 @@ public class PROCore {
outFileName = prefix + ".all";
}
- try {
- PrintWriter outFile = new PrintWriter(outFileName);
-
+ try (PrintWriter outFile = new PrintWriter(outFileName);) {
BufferedReader[] inFile = new BufferedReader[numFiles];
int nextIndex;
@@ -2638,14 +2625,14 @@ public class PROCore {
} else {
nextIndex = 1;
}
- int lineCount = countLines(prefix + nextIndex);
+ int lineCount = new ExistingUTF8EncodedTextFile(prefix + nextIndex).getNumberOfLines();
for (int r = 0; r < numFiles; ++r) {
- if (countLines(prefix + nextIndex) != lineCount) {
+ if (new ExistingUTF8EncodedTextFile(prefix + nextIndex).getNumberOfLines() != lineCount) {
throw new RuntimeException("Line count mismatch in " + (prefix + nextIndex) + ".");
}
InputStream inStream = new FileInputStream(new File(prefix + nextIndex));
- inFile[r] = new BufferedReader(new InputStreamReader(inStream, "utf8"));
+ inFile[r] = new BufferedReader(new InputStreamReader(inStream, StandardCharsets.UTF_8));
++nextIndex;
}
@@ -2658,8 +2645,6 @@ public class PROCore {
}
}
- outFile.close();
-
for (int r = 0; r < numFiles; ++r) {
inFile[r].close();
}
@@ -2802,109 +2787,11 @@ public class PROCore {
return str;
}
- private int countLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null)
- ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
- private int countNonEmptyLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null && line.length() > 0)
- ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
private String fullPath(String dir, String fileName) {
File dummyFile = new File(dir, fileName);
return dummyFile.getAbsolutePath();
}
- @SuppressWarnings("unused")
- private void cleanupMemory() {
- cleanupMemory(100, false);
- }
-
- @SuppressWarnings("unused")
- private void cleanupMemorySilently() {
- cleanupMemory(100, true);
- }
-
- @SuppressWarnings("static-access")
- private void cleanupMemory(int reps, boolean silent) {
- int bytesPerMB = 1024 * 1024;
-
- long totalMemBefore = myRuntime.totalMemory();
- long freeMemBefore = myRuntime.freeMemory();
- long usedMemBefore = totalMemBefore - freeMemBefore;
-
- long usedCurr = usedMemBefore;
- long usedPrev = usedCurr;
-
- // perform garbage collection repeatedly, until there is no decrease in
- // the amount of used memory
- for (int i = 1; i <= reps; ++i) {
- myRuntime.runFinalization();
- myRuntime.gc();
- (Thread.currentThread()).yield();
-
- usedPrev = usedCurr;
- usedCurr = myRuntime.totalMemory() - myRuntime.freeMemory();
-
- if (usedCurr == usedPrev)
- break;
- }
-
- if (!silent) {
- long totalMemAfter = myRuntime.totalMemory();
- long freeMemAfter = myRuntime.freeMemory();
- long usedMemAfter = totalMemAfter - freeMemAfter;
-
- println("GC: d_used = " + ((usedMemAfter - usedMemBefore) / bytesPerMB) + " MB "
- + "(d_tot = " + ((totalMemAfter - totalMemBefore) / bytesPerMB) + " MB).", 2);
- }
- }
-
- @SuppressWarnings("unused")
- private void printMemoryUsage() {
- int bytesPerMB = 1024 * 1024;
- long totalMem = myRuntime.totalMemory();
- long freeMem = myRuntime.freeMemory();
- long usedMem = totalMem - freeMem;
-
- println("Allocated memory: " + (totalMem / bytesPerMB) + " MB " + "(of which "
- + (usedMem / bytesPerMB) + " MB is being used).", 2);
- }
-
private void println(Object obj, int priority) {
if (priority <= verbosity)
println(obj);
@@ -2923,20 +2810,12 @@ public class PROCore {
System.out.print(obj);
}
- @SuppressWarnings("unused")
- private void showProgress() {
- ++progress;
- if (progress % 100000 == 0)
- print(".", 2);
- }
-
private ArrayList<Double> randomLambda() {
ArrayList<Double> retLambda = new ArrayList<Double>(1 + numParams);
for (int c = 1; c <= numParams; ++c) {
if (isOptimizable[c]) {
double randVal = randGen.nextDouble(); // number in [0.0,1.0]
- ++generatedRands;
randVal = randVal * (maxRandValue[c] - minRandValue[c]); // number in [0.0,max-min]
randVal = minRandValue[c] + randVal; // number in [min,max]
retLambda.set(c, randVal);
@@ -2947,81 +2826,4 @@ public class PROCore {
return retLambda;
}
-
- private double[] randomPerturbation(double[] origLambda, int i, double method, double param,
- double mult) {
- double sigma = 0.0;
- if (method == 1) {
- sigma = 1.0 / Math.pow(i, param);
- } else if (method == 2) {
- sigma = Math.exp(-param * i);
- } else if (method == 3) {
- sigma = Math.max(0.0, 1.0 - (i / param));
- }
-
- sigma = mult * sigma;
-
- double[] retLambda = new double[1 + numParams];
-
- for (int c = 1; c <= numParams; ++c) {
- if (isOptimizable[c]) {
- double randVal = 2 * randGen.nextDouble() - 1.0; // number in [-1.0,1.0]
- ++generatedRands;
- randVal = randVal * sigma; // number in [-sigma,sigma]
- randVal = randVal * origLambda[c]; // number in [-sigma*orig[c],sigma*orig[c]]
- randVal = randVal + origLambda[c]; // number in
- // [orig[c]-sigma*orig[c],orig[c]+sigma*orig[c]]
- // = [orig[c]*(1-sigma),orig[c]*(1+sigma)]
- retLambda[c] = randVal;
- } else {
- retLambda[c] = origLambda[c];
- }
- }
-
- return retLambda;
- }
-
- @SuppressWarnings("unused")
- private HashSet<Integer> indicesToDiscard(double[] slope, double[] offset) {
- // some lines can be eliminated: the ones that have a lower offset
- // than some other line with the same slope.
- // That is, for any k1 and k2:
- // if slope[k1] = slope[k2] and offset[k1] > offset[k2],
- // then k2 can be eliminated.
- // (This is actually important to do as it eliminates a bug.)
- // print("discarding: ",4);
-
- int numCandidates = slope.length;
- HashSet<Integer> discardedIndices = new HashSet<Integer>();
- HashMap<Double, Integer> indicesOfSlopes = new HashMap<Double, Integer>();
- // maps slope to index of best candidate that has that slope.
- // ("best" as in the one with the highest offset)
-
- for (int k1 = 0; k1 < numCandidates; ++k1) {
- double currSlope = slope[k1];
- if (!indicesOfSlopes.containsKey(currSlope)) {
- indicesOfSlopes.put(currSlope, k1);
- } else {
- int existingIndex = indicesOfSlopes.get(currSlope);
- if (offset[existingIndex] > offset[k1]) {
- discardedIndices.add(k1);
- // print(k1 + " ",4);
- } else if (offset[k1] > offset[existingIndex]) {
- indicesOfSlopes.put(currSlope, k1);
- discardedIndices.add(existingIndex);
- // print(existingIndex + " ",4);
- }
- }
- }
-
- // old way of doing it; takes quadratic time (vs. linear time above)
- /*
- * for (int k1 = 0; k1 < numCandidates; ++k1) { for (int k2 = 0; k2 < numCandidates; ++k2) { if
- * (k1 != k2 && slope[k1] == slope[k2] && offset[k1] > offset[k2]) { discardedIndices.add(k2);
- * // print(k2 + " ",4); } } }
- */
-
- // println("",4);
- return discardedIndices;
- } // indicesToDiscard(double[] slope, double[] offset)
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/tools/GrammarPacker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/tools/GrammarPacker.java b/src/main/java/org/apache/joshua/tools/GrammarPacker.java
index b39b775..e416d03 100644
--- a/src/main/java/org/apache/joshua/tools/GrammarPacker.java
+++ b/src/main/java/org/apache/joshua/tools/GrammarPacker.java
@@ -54,13 +54,13 @@ public class GrammarPacker {
/**
* The packed grammar version number. Increment this any time you add new features, and update
* the documentation.
- *
+ *
* Version history:
- *
+ *
* - 3 (May 2016). This was the first version that was marked. It removed the special phrase-
* table packing that packed phrases without the [X,1] on the source and target sides, which
* then required special handling in the decoder to use for phrase-based decoding.
- *
+ *
* - 4 (August 2016). Phrase-based decoding rewritten to represent phrases without a builtin
* nonterminal. Instead, cost-less glue rules are used in phrase-based decoding. This eliminates
* the need for special handling of phrase grammars (except for having to add a LHS), and lets
@@ -68,7 +68,7 @@ public class GrammarPacker {
*
*/
public static final int VERSION = 4;
-
+
// Size limit for slice in bytes.
private static int DATA_SIZE_LIMIT = (int) (Integer.MAX_VALUE * 0.8);
// Estimated average number of feature entries for one rule.
@@ -145,30 +145,30 @@ public class GrammarPacker {
}
private void readConfig(String config_filename) throws IOException {
- LineReader reader = new LineReader(config_filename);
- while (reader.hasNext()) {
- // Clean up line, chop comments off and skip if the result is empty.
- String line = reader.next().trim();
- if (line.indexOf('#') != -1)
- line = line.substring(0, line.indexOf('#'));
- if (line.isEmpty())
- continue;
- String[] fields = line.split("[\\s]+");
-
- if (fields.length < 2) {
- throw new RuntimeException("Incomplete line in config.");
- }
- if ("slice_size".equals(fields[0])) {
- // Number of records to concurrently load into memory for sorting.
- approximateMaximumSliceSize = Integer.parseInt(fields[1]);
+ try(LineReader reader = new LineReader(config_filename);) {
+ while (reader.hasNext()) {
+ // Clean up line, chop comments off and skip if the result is empty.
+ String line = reader.next().trim();
+ if (line.indexOf('#') != -1)
+ line = line.substring(0, line.indexOf('#'));
+ if (line.isEmpty())
+ continue;
+ String[] fields = line.split("[\\s]+");
+
+ if (fields.length < 2) {
+ throw new RuntimeException("Incomplete line in config.");
+ }
+ if ("slice_size".equals(fields[0])) {
+ // Number of records to concurrently load into memory for sorting.
+ approximateMaximumSliceSize = Integer.parseInt(fields[1]);
+ }
}
}
- reader.close();
}
/**
* Executes the packing.
- *
+ *
* @throws IOException if there is an error reading the grammar
*/
public void pack() throws IOException {
@@ -226,23 +226,24 @@ public class GrammarPacker {
/**
* Returns a reader that turns whatever file format is found into Hiero grammar rules.
- *
+ *
* @param grammarFile
* @return
* @throws IOException
*/
private HieroFormatReader getGrammarReader() throws IOException {
- LineReader reader = new LineReader(grammar);
- String line = reader.next();
- if (line.startsWith("[")) {
- return new HieroFormatReader(grammar);
- } else {
- return new MosesFormatReader(grammar);
+ try (LineReader reader = new LineReader(grammar);) {
+ String line = reader.next();
+ if (line.startsWith("[")) {
+ return new HieroFormatReader(grammar);
+ } else {
+ return new MosesFormatReader(grammar);
+ }
}
}
/**
- * This first pass over the grammar
+ * This first pass over the grammar
* @param reader
*/
private void explore(HieroFormatReader reader) {
@@ -258,9 +259,9 @@ public class GrammarPacker {
/* Add symbols to vocabulary.
* NOTE: In case of nonterminals, we add both stripped versions ("[X]")
* and "[X,1]" to the vocabulary.
- *
+ *
* TODO: MJP May 2016: Is it necessary to add [X,1]? This is currently being done in
- * {@link HieroFormatReader}, which is called by {@link MosesFormatReader}.
+ * {@link HieroFormatReader}, which is called by {@link MosesFormatReader}.
*/
// Add feature names to vocabulary and pass the value through the
@@ -438,7 +439,7 @@ public class GrammarPacker {
* written simultaneously. The source structure is written into a downward-pointing trie and
* stores the rule's lhs as well as links to the target and feature stream. The feature stream is
* prompted to write out a block
- *
+ *
* @param source_trie
* @param target_trie
* @param feature_buffer
@@ -580,9 +581,9 @@ public class GrammarPacker {
/**
* Integer-labeled, doubly-linked trie with some provisions for packing.
- *
+ *
* @author Juri Ganitkevitch
- *
+ *
* @param <D> The trie's value type.
*/
class PackingTrie<D extends PackingTrieValue> {
@@ -632,10 +633,10 @@ public class GrammarPacker {
* points to children) from upwards pointing (children point to parent) tries, as well as
* skeletal (no data, just the labeled links) and non-skeletal (nodes have a data block)
* packing.
- *
+ *
* @param downwards Are we packing into a downwards-pointing trie?
* @param skeletal Are we packing into a skeletal trie?
- *
+ *
* @return Number of bytes the trie node would occupy.
*/
int size(boolean downwards, boolean skeletal) {
@@ -684,6 +685,7 @@ public class GrammarPacker {
this.target = target;
}
+ @Override
public int size() {
return 3;
}
@@ -696,6 +698,7 @@ public class GrammarPacker {
this.parent = parent;
}
+ @Override
public int size() {
return 0;
}
@@ -753,7 +756,7 @@ public class GrammarPacker {
/**
* Enqueue a data block for later writing.
- *
+ *
* @param block_index The index of the data block to add to writing queue.
* @return The to-be-written block's output index.
*/
@@ -765,7 +768,7 @@ public class GrammarPacker {
/**
* Performs the actual writing to disk in the order specified by calls to write() since the last
* call to initialize().
- *
+ *
* @param out
* @throws IOException
*/
@@ -828,10 +831,11 @@ public class GrammarPacker {
/**
* Add a block of features to the buffer.
- *
+ *
* @param features TreeMap with the features for one rule.
* @return The index of the resulting data block.
*/
+ @Override
int add(TreeMap<Integer, Float> features) {
int data_position = buffer.position();
@@ -871,10 +875,11 @@ public class GrammarPacker {
/**
* Add a rule alignments to the buffer.
- *
+ *
* @param alignments a byte array with the alignment points for one rule.
* @return The index of the resulting data block.
*/
+ @Override
int add(byte[] alignments) {
int data_position = buffer.position();
int size_estimate = alignments.length + 1;
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/tools/LabelPhrases.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/tools/LabelPhrases.java b/src/main/java/org/apache/joshua/tools/LabelPhrases.java
index 2fd2b3f..8f15b0e 100644
--- a/src/main/java/org/apache/joshua/tools/LabelPhrases.java
+++ b/src/main/java/org/apache/joshua/tools/LabelPhrases.java
@@ -28,7 +28,7 @@ import org.slf4j.LoggerFactory;
/**
* Finds labeling for a set of phrases.
- *
+ *
* @author Juri Ganitkevitch
*/
public class LabelPhrases {
@@ -37,7 +37,7 @@ public class LabelPhrases {
/**
* Main method.
- *
+ *
* @param args names of the two grammars to be compared
* @throws IOException if there is an error reading the input grammars
*/
@@ -60,52 +60,52 @@ public class LabelPhrases {
System.exit(-1);
}
- LineReader phrase_reader = new LineReader(phrase_file_name);
-
- while (phrase_reader.ready()) {
- String line = phrase_reader.readLine();
-
- String[] fields = line.split("\\t");
- if (fields.length != 3 || fields[2].equals("()")) {
- System.err.println("[FAIL] Empty parse in line:\t" + line);
- continue;
- }
-
- String[] phrase_strings = fields[0].split("\\s");
- int[] phrase_ids = new int[phrase_strings.length];
- for (int i = 0; i < phrase_strings.length; i++)
- phrase_ids[i] = Vocabulary.id(phrase_strings[i]);
+ try (LineReader phrase_reader = new LineReader(phrase_file_name);) {
+ while (phrase_reader.ready()) {
+ String line = phrase_reader.readLine();
- ArraySyntaxTree syntax = new ArraySyntaxTree(fields[2]);
- int[] sentence_ids = syntax.getTerminals();
+ String[] fields = line.split("\\t");
+ if (fields.length != 3 || fields[2].equals("()")) {
+ System.err.println("[FAIL] Empty parse in line:\t" + line);
+ continue;
+ }
- int match_start = -1;
- int match_end = -1;
- for (int i = 0; i < sentence_ids.length; i++) {
- if (phrase_ids[0] == sentence_ids[i]) {
- match_start = i;
- int j = 0;
- while (j < phrase_ids.length && phrase_ids[j] == sentence_ids[i + j]) {
- j++;
- }
- if (j == phrase_ids.length) {
- match_end = i + j;
- break;
+ String[] phrase_strings = fields[0].split("\\s");
+ int[] phrase_ids = new int[phrase_strings.length];
+ for (int i = 0; i < phrase_strings.length; i++)
+ phrase_ids[i] = Vocabulary.id(phrase_strings[i]);
+
+ ArraySyntaxTree syntax = new ArraySyntaxTree(fields[2]);
+ int[] sentence_ids = syntax.getTerminals();
+
+ int match_start = -1;
+ int match_end = -1;
+ for (int i = 0; i < sentence_ids.length; i++) {
+ if (phrase_ids[0] == sentence_ids[i]) {
+ match_start = i;
+ int j = 0;
+ while (j < phrase_ids.length && phrase_ids[j] == sentence_ids[i + j]) {
+ j++;
+ }
+ if (j == phrase_ids.length) {
+ match_end = i + j;
+ break;
+ }
}
}
- }
- int label = syntax.getOneConstituent(match_start, match_end);
- if (label == 0) label = syntax.getOneSingleConcatenation(match_start, match_end);
- if (label == 0) label = syntax.getOneRightSideCCG(match_start, match_end);
- if (label == 0) label = syntax.getOneLeftSideCCG(match_start, match_end);
- if (label == 0) label = syntax.getOneDoubleConcatenation(match_start, match_end);
- if (label == 0) {
- System.err.println("[FAIL] No label found in line:\t" + line);
- continue;
- }
+ int label = syntax.getOneConstituent(match_start, match_end);
+ if (label == 0) label = syntax.getOneSingleConcatenation(match_start, match_end);
+ if (label == 0) label = syntax.getOneRightSideCCG(match_start, match_end);
+ if (label == 0) label = syntax.getOneLeftSideCCG(match_start, match_end);
+ if (label == 0) label = syntax.getOneDoubleConcatenation(match_start, match_end);
+ if (label == 0) {
+ System.err.println("[FAIL] No label found in line:\t" + line);
+ continue;
+ }
- System.out.println(Vocabulary.word(label) + "\t" + line);
+ System.out.println(Vocabulary.word(label) + "\t" + line);
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/tools/TestSetFilter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/tools/TestSetFilter.java b/src/main/java/org/apache/joshua/tools/TestSetFilter.java
index ecb2e6e..f73f02d 100644
--- a/src/main/java/org/apache/joshua/tools/TestSetFilter.java
+++ b/src/main/java/org/apache/joshua/tools/TestSetFilter.java
@@ -57,7 +57,7 @@ public class TestSetFilter {
acceptedLastSourceSide = false;
lastSourceSide = null;
}
-
+
public String getFilterName() {
if (filter != null)
if (filter instanceof FastFilter)
@@ -109,7 +109,7 @@ public class TestSetFilter {
}
/**
- * Top-level filter, responsible for calling the fast or exact version. Takes the source side
+ * Top-level filter, responsible for calling the fast or exact version. Takes the source side
* of a rule and determines whether there is any sentence in the test set that can match it.
* @param sourceSide an input source sentence
* @return true if is any sentence in the test set can match the source input
@@ -124,11 +124,11 @@ public class TestSetFilter {
return acceptedLastSourceSide;
}
-
+
/**
* Determines whether a rule is an abstract rule. An abstract rule is one that has no terminals on
* its source side.
- *
+ *
* If the rule is abstract, the rule's arity is returned. Otherwise, 0 is returned.
*/
private boolean isAbstract(String source) {
@@ -144,7 +144,7 @@ public class TestSetFilter {
private interface Filter {
/* Tell the filter about a sentence in the test set being filtered to */
public void addSentence(String sentence);
-
+
/* Returns true if the filter permits the specified source side */
public boolean permits(String sourceSide);
}
@@ -155,7 +155,7 @@ public class TestSetFilter {
public FastFilter() {
ngrams = new HashSet<String>();
}
-
+
@Override
public boolean permits(String source) {
for (String chunk : source.split(NT_REGEX)) {
@@ -194,7 +194,7 @@ public class TestSetFilter {
public LooseFilter() {
testSentences = new ArrayList<String>();
}
-
+
@Override
public void addSentence(String source) {
testSentences.add(source);
@@ -227,13 +227,13 @@ public class TestSetFilter {
private FastFilter fastFilter = null;
private Map<String, Set<Integer>> sentencesByWord;
List<String> testSentences = null;
-
+
public ExactFilter() {
fastFilter = new FastFilter();
sentencesByWord = new HashMap<String, Set<Integer>>();
testSentences = new ArrayList<String>();
}
-
+
@Override
public void addSentence(String source) {
fastFilter.addSentence(source);
@@ -243,13 +243,13 @@ public class TestSetFilter {
/**
* Always permit abstract rules. Otherwise, query the fast filter, and if that passes, apply
- *
+ *
*/
@Override
public boolean permits(String sourceSide) {
if (isAbstract(sourceSide))
return true;
-
+
if (fastFilter.permits(sourceSide)) {
Pattern pattern = getPattern(sourceSide);
for (int i : getSentencesForRule(sourceSide)) {
@@ -257,10 +257,10 @@ public class TestSetFilter {
return true;
}
}
- }
+ }
return false;
}
-
+
protected Pattern getPattern(String source) {
String pattern = Pattern.quote(source);
pattern = pattern.replaceAll(NT_REGEX, "\\\\E.+\\\\Q");
@@ -268,7 +268,7 @@ public class TestSetFilter {
pattern = "(?:^|\\s)" + pattern + "(?:$|\\s)";
return Pattern.compile(pattern);
}
-
+
/*
* Map words to all the sentences they appear in.
*/
@@ -280,7 +280,7 @@ public class TestSetFilter {
sentencesByWord.get(t).add(index);
}
}
-
+
private Set<Integer> getSentencesForRule(String source) {
Set<Integer> sentences = null;
for (String token : source.split("\\s+")) {
@@ -293,7 +293,7 @@ public class TestSetFilter {
}
}
}
-
+
return sentences;
}
}
@@ -311,7 +311,7 @@ public class TestSetFilter {
System.err.println(" -n max n-gram to compare to (default 12)");
return;
}
-
+
String grammarFile = null;
TestSetFilter filter = new TestSetFilter();
@@ -350,34 +350,35 @@ public class TestSetFilter {
System.err.println(String.format("Filtering rules with the %s filter...", filter.getFilterName()));
// System.err.println("Using at max " + filter.RULE_LENGTH + " n-grams...");
}
- LineReader reader = (grammarFile != null)
+ try(LineReader reader = (grammarFile != null)
? new LineReader(grammarFile, filter.verbose)
- : new LineReader(System.in);
- for (String rule: reader) {
- rulesIn++;
-
- String[] parts = P_DELIM.split(rule);
- if (parts.length >= 4) {
- // the source is the second field for thrax grammars, first field for phrasal ones
- String source = rule.startsWith("[") ? parts[1].trim() : parts[0].trim();
- if (filter.inTestSet(source)) {
- System.out.println(rule);
- if (filter.parallel)
+ : new LineReader(System.in);) {
+ for (String rule: reader) {
+ rulesIn++;
+
+ String[] parts = P_DELIM.split(rule);
+ if (parts.length >= 4) {
+ // the source is the second field for thrax grammars, first field for phrasal ones
+ String source = rule.startsWith("[") ? parts[1].trim() : parts[0].trim();
+ if (filter.inTestSet(source)) {
+ System.out.println(rule);
+ if (filter.parallel)
+ System.out.flush();
+ rulesOut++;
+ } else if (filter.parallel) {
+ System.out.println("");
System.out.flush();
- rulesOut++;
- } else if (filter.parallel) {
- System.out.println("");
- System.out.flush();
+ }
}
}
- }
- if (filter.verbose) {
- System.err.println("[INFO] Total rules read: " + rulesIn);
- System.err.println("[INFO] Rules kept: " + rulesOut);
- System.err.println("[INFO] Rules dropped: " + (rulesIn - rulesOut));
- System.err.println("[INFO] cached queries: " + filter.cached);
- }
+ if (filter.verbose) {
+ System.err.println("[INFO] Total rules read: " + rulesIn);
+ System.err.println("[INFO] Rules kept: " + rulesOut);
+ System.err.println("[INFO] Rules dropped: " + (rulesIn - rulesOut));
+ System.err.println("[INFO] cached queries: " + filter.cached);
+ }
- return;
+ return;
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/BotMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/BotMap.java b/src/main/java/org/apache/joshua/util/BotMap.java
deleted file mode 100644
index 1cc82b5..0000000
--- a/src/main/java/org/apache/joshua/util/BotMap.java
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Gets a special map that maps any key to the a particular value.
- *
- * @author Lane Schwartz
- * @see "Lopez (2008), footnote 9 on p73"
- */
-public class BotMap<K, V> implements Map<K, V> {
-
- /** Special value, which this map will return for every key. */
- private final V value;
-
- /**
- * Constructs a special map that maps any key to the a particular value.
- *
- * @param value Special value, which this map will return for every key.
- */
- public BotMap(V value) {
- this.value = value;
- }
-
- public void clear() {
- throw new UnsupportedOperationException();
- }
-
- public boolean containsKey(Object key) {
- return true;
- }
-
- public boolean containsValue(Object value) {
- return this.value == value;
- }
-
- public Set<Map.Entry<K, V>> entrySet() {
- throw new UnsupportedOperationException();
- }
-
- public V get(Object key) {
- return value;
- }
-
- public boolean isEmpty() {
- return false;
- }
-
- public Set<K> keySet() {
- throw new UnsupportedOperationException();
- }
-
- public V put(K key, V value) {
- throw new UnsupportedOperationException();
- }
-
- public void putAll(Map<? extends K, ? extends V> t) {
- throw new UnsupportedOperationException();
- }
-
- public V remove(Object key) {
- throw new UnsupportedOperationException();
- }
-
- public int size() {
- throw new UnsupportedOperationException();
- }
-
- public Collection<V> values() {
- return Collections.singleton(value);
- }
-
-}
[09/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/pro/PROCore.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/pro/PROCore.java
index 39b34a5,0000000..aba9d6b
mode 100755,000000..100755
--- a/joshua-core/src/main/java/org/apache/joshua/pro/PROCore.java
+++ b/joshua-core/src/main/java/org/apache/joshua/pro/PROCore.java
@@@ -1,3027 -1,0 +1,2829 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.pro;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
++import java.nio.charset.StandardCharsets;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Scanner;
+import java.util.TreeSet;
+import java.util.Vector;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.metrics.EvaluationMetric;
+import org.apache.joshua.util.StreamGobbler;
-
++import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This code was originally written by Yuan Cao, who copied the MERT code to produce this file.
+ */
+
+public class PROCore {
+
+ private static final Logger LOG = LoggerFactory.getLogger(PROCore.class);
+
+ private final JoshuaConfiguration joshuaConfiguration;
+ private TreeSet<Integer>[] indicesOfInterest_all;
+
+ private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
- private final Runtime myRuntime = Runtime.getRuntime();
+
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+ private final static double epsilon = 1.0 / 1000000;
+
- private int progress;
-
+ private int verbosity; // anything of priority <= verbosity will be printed
+ // (lower value for priority means more important)
+
+ private Random randGen;
- private int generatedRands;
+
+ private int numSentences;
+ // number of sentences in the dev set
+ // (aka the "MERT training" set)
+
+ private int numDocuments;
+ // number of documents in the dev set
+ // this should be 1, unless doing doc-level optimization
+
+ private int[] docOfSentence;
+ // docOfSentence[i] stores which document contains the i'th sentence.
+ // docOfSentence is 0-indexed, as are the documents (i.e. first doc is indexed 0)
+
+ private int[] docSubsetInfo;
+ // stores information regarding which subset of the documents are evaluated
+ // [0]: method (0-6)
+ // [1]: first (1-indexed)
+ // [2]: last (1-indexed)
+ // [3]: size
+ // [4]: center
+ // [5]: arg1
+ // [6]: arg2
+ // [1-6] are 0 for method 0, [6] is 0 for methods 1-4 as well
+ // only [1] and [2] are needed for optimization. The rest are only needed for an output message.
+
+ private int refsPerSen;
+ // number of reference translations per sentence
+
+ private int textNormMethod;
+ // 0: no normalization, 1: "NIST-style" tokenization, and also rejoin 'm, 're, *'s, 've, 'll, 'd,
+ // and n't,
+ // 2: apply 1 and also rejoin dashes between letters, 3: apply 1 and also drop non-ASCII
+ // characters
+ // 4: apply 1+2+3
+
+ private int numParams;
+ // total number of firing features
+ // this number may increase overtime as new n-best lists are decoded
+ // initially it is equal to the # of params in the parameter config file
+ private int numParamsOld;
+ // number of features before observing the new features fired in the current iteration
+
+ private double[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ /* *********************************************************** */
+ /* NOTE: indexing starts at 1 in the following few arrays: */
+ /* *********************************************************** */
+
+ // private double[] lambda;
+ private ArrayList<Double> lambda = new ArrayList<Double>();
+ // the current weight vector. NOTE: indexing starts at 1.
+ private ArrayList<Double> bestLambda = new ArrayList<Double>();
+ // the best weight vector across all iterations
+
+ private boolean[] isOptimizable;
+ // isOptimizable[c] = true iff lambda[c] should be optimized
+
+ private double[] minRandValue;
+ private double[] maxRandValue;
+ // when choosing a random value for the lambda[c] parameter, it will be
+ // chosen from the [minRandValue[c],maxRandValue[c]] range.
+ // (*) minRandValue and maxRandValue must be real values, but not -Inf or +Inf
+
+ private double[] defaultLambda;
+ // "default" parameter values; simply the values read in the parameter file
+ // USED FOR NON-OPTIMIZABLE (FIXED) FEATURES
+
+ /* *********************************************************** */
+ /* *********************************************************** */
+
+ private Decoder myDecoder;
+ // COMMENT OUT if decoder is not Joshua
+
+ private String decoderCommand;
+ // the command that runs the decoder; read from decoderCommandFileName
+
+ private int decVerbosity;
+ // verbosity level for decoder output. If 0, decoder output is ignored.
+ // If 1, decoder output is printed.
+
+ private int validDecoderExitValue;
+ // return value from running the decoder command that indicates success
+
+ private int numOptThreads;
+ // number of threads to run things in parallel
+
+ private int saveInterFiles;
+ // 0: nothing, 1: only configs, 2: only n-bests, 3: both configs and n-bests
+
+ private int compressFiles;
+ // should PRO gzip the large files? If 0, no compression takes place.
+ // If 1, compression is performed on: decoder output files, temp sents files,
+ // and temp feats files.
+
+ private int sizeOfNBest;
+ // size of N-best list generated by decoder at each iteration
+ // (aka simply N, but N is a bad variable name)
+
+ private long seed;
+ // seed used to create random number generators
+
+ private boolean randInit;
+ // if true, parameters are initialized randomly. If false, parameters
+ // are initialized using values from parameter file.
+
+ private int maxMERTIterations, minMERTIterations, prevMERTIterations;
+ // max: maximum number of MERT iterations
+ // min: minimum number of MERT iterations before an early MERT exit
+ // prev: number of previous MERT iterations from which to consider candidates (in addition to
+ // the candidates from the current iteration)
+
+ private double stopSigValue;
+ // early MERT exit if no weight changes by more than stopSigValue
+ // (but see minMERTIterations above and stopMinIts below)
+
+ private int stopMinIts;
+ // some early stopping criterion must be satisfied in stopMinIts *consecutive* iterations
+ // before an early exit (but see minMERTIterations above)
+
+ private boolean oneModificationPerIteration;
+ // if true, each MERT iteration performs at most one parameter modification.
+ // If false, a new MERT iteration starts (i.e. a new N-best list is
+ // generated) only after the previous iteration reaches a local maximum.
+
+ private String metricName;
+ // name of evaluation metric optimized by MERT
+
+ private String metricName_display;
+ // name of evaluation metric optimized by MERT, possibly with "doc-level " prefixed
+
+ private String[] metricOptions;
+ // options for the evaluation metric (e.g. for BLEU, maxGramLength and effLengthMethod)
+
+ private EvaluationMetric evalMetric;
+ // the evaluation metric used by MERT
+
+ private int suffStatsCount;
+ // number of sufficient statistics for the evaluation metric
+
+ private String tmpDirPrefix;
+ // prefix for the PRO.temp.* files
+
+ private boolean passIterationToDecoder;
+ // should the iteration number be passed as an argument to decoderCommandFileName?
+
+ // used for pro
+ private String classifierAlg; // the classification algorithm(percep, megam, maxent ...)
+ private String[] classifierParams = null; // the param array for each classifier
+ private int Tau;
+ private int Xi;
+ private double interCoef;
+ private double metricDiff;
+ private double prevMetricScore = 0; // final metric score of the previous iteration, used only
+ // when returnBest = true
+ private boolean returnBest = false; // return the best weight during tuning
+
+ private String dirPrefix; // where are all these files located?
+ private String paramsFileName, docInfoFileName, finalLambdaFileName;
+ private String sourceFileName, refFileName, decoderOutFileName;
+ private String decoderConfigFileName, decoderCommandFileName;
+ private String fakeFileNameTemplate, fakeFileNamePrefix, fakeFileNameSuffix;
+
+ // e.g. output.it[1-x].someOldRun would be specified as:
+ // output.it?.someOldRun
+ // and we'd have prefix = "output.it" and suffix = ".sameOldRun"
+
+ // private int useDisk;
+
+ public PROCore(JoshuaConfiguration joshuaConfiguration) {
+ this.joshuaConfiguration = joshuaConfiguration;
+ }
+
- public PROCore(String[] args, JoshuaConfiguration joshuaConfiguration) {
++ public PROCore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(args);
+ initialize(0);
+ }
+
- public PROCore(String configFileName, JoshuaConfiguration joshuaConfiguration) {
++ public PROCore(String configFileName, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(cfgFileToArgsArray(configFileName));
+ initialize(0);
+ }
+
- private void initialize(int randsToSkip) {
++ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
+ println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
+
+ randGen = new Random(seed);
+ for (int r = 1; r <= randsToSkip; ++r) {
+ randGen.nextDouble();
+ }
- generatedRands = randsToSkip;
+
+ if (randsToSkip == 0) {
+ println("----------------------------------------------------", 1);
+ println("Initializing...", 1);
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ println("Random number generator initialized using seed: " + seed, 1);
+ println("", 1);
+ }
+
+ // COUNT THE TOTAL NUM OF SENTENCES TO BE DECODED, refFileName IS THE COMBINED REFERENCE FILE
+ // NAME(AUTO GENERATED)
- numSentences = countLines(refFileName) / refsPerSen;
++ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines() / refsPerSen;
+
+ // ??
+ processDocInfo();
+ // sets numDocuments and docOfSentence[]
+
+ if (numDocuments > 1)
+ metricName_display = "doc-level " + metricName;
+
+ // ??
+ set_docSubsetInfo(docSubsetInfo);
+
+ // count the number of initial features
- numParams = countNonEmptyLines(paramsFileName) - 1;
++ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
+ numParamsOld = numParams;
+
+ // read parameter config file
+ try {
+ // read dense parameter names
+ BufferedReader inFile_names = new BufferedReader(new FileReader(paramsFileName));
+
+ for (int c = 1; c <= numParams; ++c) {
+ String line = "";
+ while (line != null && line.length() == 0) { // skip empty lines
+ line = inFile_names.readLine();
+ }
+
+ // save feature names
+ String paramName = (line.substring(0, line.indexOf("|||"))).trim();
+ Vocabulary.id(paramName);
+ // System.err.println(String.format("VOCAB(%s) = %d", paramName, id));
+ }
+
+ inFile_names.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ // the parameter file contains one line per parameter
+ // and one line for the normalization method
+ // indexing starts at 1 in these arrays
+ for (int p = 0; p <= numParams; ++p)
+ lambda.add(new Double(0));
+ bestLambda.add(new Double(0));
+ // why only lambda is a list? because the size of lambda
+ // may increase over time, but other arrays are specified in
+ // the param config file, only used for initialization
+ isOptimizable = new boolean[1 + numParams];
+ minRandValue = new double[1 + numParams];
+ maxRandValue = new double[1 + numParams];
+ defaultLambda = new double[1 + numParams];
+ normalizationOptions = new double[3];
+
+ // read initial param values
+ processParamFile();
+ // sets the arrays declared just above
+
+ // SentenceInfo.createV(); // uncomment ONLY IF using vocabulary implementation of SentenceInfo
+
+ String[][] refSentences = new String[numSentences][refsPerSen];
+
+ try {
+
+ // read in reference sentences
+ InputStream inStream_refs = new FileInputStream(new File(refFileName));
+ BufferedReader inFile_refs = new BufferedReader(new InputStreamReader(inStream_refs, "utf8"));
+
+ for (int i = 0; i < numSentences; ++i) {
+ for (int r = 0; r < refsPerSen; ++r) {
+ // read the rth reference translation for the ith sentence
+ refSentences[i][r] = inFile_refs.readLine();
+ }
+ }
+
+ inFile_refs.close();
+
+ // normalize reference sentences
+ for (int i = 0; i < numSentences; ++i) {
+ for (int r = 0; r < refsPerSen; ++r) {
+ // normalize the rth reference translation for the ith sentence
+ refSentences[i][r] = normalize(refSentences[i][r], textNormMethod);
+ }
+ }
+
+ // read in decoder command, if any
+ decoderCommand = null;
+ if (decoderCommandFileName != null) {
+ if (fileExists(decoderCommandFileName)) {
+ BufferedReader inFile_comm = new BufferedReader(new FileReader(decoderCommandFileName));
+ decoderCommand = inFile_comm.readLine(); // READ IN DECODE COMMAND
+ inFile_comm.close();
+ }
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ // set static data members for the EvaluationMetric class
+ EvaluationMetric.set_numSentences(numSentences);
+ EvaluationMetric.set_numDocuments(numDocuments);
+ EvaluationMetric.set_refsPerSen(refsPerSen);
+ EvaluationMetric.set_refSentences(refSentences);
+ EvaluationMetric.set_tmpDirPrefix(tmpDirPrefix);
+
+ evalMetric = EvaluationMetric.getMetric(metricName, metricOptions);
+ // used only if returnBest = true
+ prevMetricScore = evalMetric.getToBeMinimized() ? PosInf : NegInf;
+
+ // length of sufficient statistics
+ // for bleu: suffstatscount=8 (2*ngram+2)
+ suffStatsCount = evalMetric.get_suffStatsCount();
+
+ // set static data members for the IntermediateOptimizer class
+ /*
+ * IntermediateOptimizer.set_MERTparams(numSentences, numDocuments, docOfSentence,
+ * docSubsetInfo, numParams, normalizationOptions, isOptimizable oneModificationPerIteration,
+ * evalMetric, tmpDirPrefix, verbosity);
+ */
+
+ // print info
+ if (randsToSkip == 0) { // i.e. first iteration
+ println("Number of sentences: " + numSentences, 1);
+ println("Number of documents: " + numDocuments, 1);
+ println("Optimizing " + metricName_display, 1);
+
+ /*
+ * print("docSubsetInfo: {", 1); for (int f = 0; f < 6; ++f) print(docSubsetInfo[f] + ", ",
+ * 1); println(docSubsetInfo[6] + "}", 1);
+ */
+
+ println("Number of initial features: " + numParams, 1);
+ print("Initial feature names: {", 1);
+
+ for (int c = 1; c <= numParams; ++c)
+ print("\"" + Vocabulary.word(c) + "\"", 1);
+ println("}", 1);
+ println("", 1);
+
+ // TODO just print the correct info
+ println("c Default value\tOptimizable?\tRand. val. range", 1);
+
+ for (int c = 1; c <= numParams; ++c) {
+ print(c + " " + f4.format(lambda.get(c).doubleValue()) + "\t\t", 1);
+
+ if (!isOptimizable[c]) {
+ println(" No", 1);
+ } else {
+ print(" Yes\t\t", 1);
+ print(" [" + minRandValue[c] + "," + maxRandValue[c] + "]", 1);
+ println("", 1);
+ }
+ }
+
+ println("", 1);
+ print("Weight vector normalization method: ", 1);
+ if (normalizationOptions[0] == 0) {
+ println("none.", 1);
+ } else if (normalizationOptions[0] == 1) {
+ println(
+ "weights will be scaled so that the \""
+ + Vocabulary.word((int) normalizationOptions[2])
+ + "\" weight has an absolute value of " + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 2) {
+ println("weights will be scaled so that the maximum absolute value is "
+ + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 3) {
+ println("weights will be scaled so that the minimum absolute value is "
+ + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 4) {
+ println("weights will be scaled so that the L-" + normalizationOptions[1] + " norm is "
+ + normalizationOptions[2] + ".", 1);
+ }
+
+ println("", 1);
+
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ // rename original config file so it doesn't get overwritten
+ // (original name will be restored in finish())
+ renameFile(decoderConfigFileName, decoderConfigFileName + ".PRO.orig");
+ } // if (randsToSkip == 0)
+
+ // by default, load joshua decoder
+ if (decoderCommand == null && fakeFileNameTemplate == null) {
+ println("Loading Joshua decoder...", 1);
+ myDecoder = new Decoder(joshuaConfiguration);
+ println("...finished loading @ " + (new Date()), 1);
+ println("");
+ } else {
+ myDecoder = null;
+ }
+
+ @SuppressWarnings("unchecked")
+ TreeSet<Integer>[] temp_TSA = new TreeSet[numSentences];
+ indicesOfInterest_all = temp_TSA;
+
+ for (int i = 0; i < numSentences; ++i) {
+ indicesOfInterest_all[i] = new TreeSet<Integer>();
+ }
+ } // void initialize(...)
+
+ // -------------------------
+
+  /**
+   * Runs PRO using the iteration limits configured on this instance
+   * ({@code minMERTIterations}, {@code maxMERTIterations}, {@code prevMERTIterations}).
+   */
+  public void run_PRO() {
+    final int minIts = minMERTIterations;
+    final int maxIts = maxMERTIterations;
+    final int prevIts = prevMERTIterations;
+    run_PRO(minIts, maxIts, prevIts);
+  }
+
+ public void run_PRO(int minIts, int maxIts, int prevIts) {
+ // FIRST, CLEAN ALL PREVIOUS TEMP FILES
+ String dir;
+ int k = tmpDirPrefix.lastIndexOf("/");
+ if (k >= 0) {
+ dir = tmpDirPrefix.substring(0, k + 1);
+ } else {
+ dir = "./";
+ }
+ String files;
+ File folder = new File(dir);
+
+ if (folder.exists()) {
+ File[] listOfFiles = folder.listFiles();
+
+ for (int i = 0; i < listOfFiles.length; i++) {
+ if (listOfFiles[i].isFile()) {
+ files = listOfFiles[i].getName();
+ if (files.startsWith("PRO.temp")) {
+ deleteFile(files);
+ }
+ }
+ }
+ }
+
+ println("----------------------------------------------------", 1);
+ println("PRO run started @ " + (new Date()), 1);
+ // printMemoryUsage();
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ // if no default lambda is provided
+ if (randInit) {
+ println("Initializing lambda[] randomly.", 1);
+ // initialize optimizable parameters randomly (sampling uniformly from
+ // that parameter's random value range)
+ lambda = randomLambda();
+ }
+
+ println("Initial lambda[]: " + lambdaToString(lambda), 1);
+ println("", 1);
+
+ int[] maxIndex = new int[numSentences];
+
+ // HashMap<Integer,int[]>[] suffStats_array = new HashMap[numSentences];
+ // suffStats_array[i] maps candidates of interest for sentence i to an array
+ // storing the sufficient statistics for that candidate
+
+ int earlyStop = 0;
+ // number of consecutive iteration an early stopping criterion was satisfied
+
+ for (int iteration = 1;; ++iteration) {
+
+ // what does "A" contain?
+ // retA[0]: FINAL_score
+ // retA[1]: earlyStop
+ // retA[2]: should this be the last iteration?
+ double[] A = run_single_iteration(iteration, minIts, maxIts, prevIts, earlyStop, maxIndex);
+ if (A != null) {
+ earlyStop = (int) A[1];
+ if (A[2] == 1)
+ break;
+ } else {
+ break;
+ }
+
+ } // for (iteration)
+
+ println("", 1);
+
+ println("----------------------------------------------------", 1);
+ println("PRO run ended @ " + (new Date()), 1);
+ // printMemoryUsage();
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ if (!returnBest)
+ println("FINAL lambda: " + lambdaToString(lambda), 1);
+ // + " (" + metricName_display + ": " + FINAL_score + ")",1);
+ else
+ println("BEST lambda: " + lambdaToString(lambda), 1);
+ // + " (" + metricName_display + ": " + FINAL_score + ")",1);
+
+ // delete intermediate .temp.*.it* decoder output files
+ for (int iteration = 1; iteration <= maxIts; ++iteration) {
+ if (compressFiles == 1) {
+ deleteFile(tmpDirPrefix + "temp.sents.it" + iteration + ".gz");
+ deleteFile(tmpDirPrefix + "temp.feats.it" + iteration + ".gz");
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".copy.gz")) {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".copy.gz");
+ } else {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz");
+ }
+ } else {
+ deleteFile(tmpDirPrefix + "temp.sents.it" + iteration);
+ deleteFile(tmpDirPrefix + "temp.feats.it" + iteration);
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".copy")) {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".copy");
+ } else {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+ }
+ }
+ } // void run_PRO(int maxIts)
+
+  // this is the key function!
+  /**
+   * Runs a single PRO iteration.
+   *
+   * <p>Steps, in order:
+   * <ol>
+   *   <li>Write the decoder config from the current {@code lambda}.</li>
+   *   <li>Run the decoder and turn its output into per-iteration temp files.</li>
+   *   <li>Merge the candidates from iterations {@code max(1, iteration - prevIts)} through
+   *       {@code iteration}. Sufficient statistics are computed (via
+   *       {@code evalMetric.createSuffStatsFile}) only for candidates not seen before.</li>
+   *   <li>Run the {@code Optimizer}.</li>
+   *   <li>Decide whether to stop. If not, interpolate {@code lambda} towards the optimizer's
+   *       result using {@code interCoef}.</li>
+   * </ol>
+   *
+   * <p>Side effects visible here:
+   * <ul>
+   *   <li>Mutates {@code lambda}, {@code numParams} and {@code numParamsOld}.</li>
+   *   <li>When {@code returnBest}, also mutates {@code bestLambda} and
+   *       {@code prevMetricScore}.</li>
+   *   <li>Creates, copies and deletes files under {@code tmpDirPrefix}.</li>
+   * </ul>
+   *
+   * @param iteration 1-based index of the current iteration
+   * @param minIts early stopping is only allowed once {@code iteration >= minIts}
+   * @param maxIts once {@code iteration >= maxIts} this iteration is the last one
+   * @param prevIts how many earlier iterations' candidates are merged with the current ones
+   * @param earlyStop number of consecutive preceding iterations that met the early-stopping
+   *     criterion
+   * @param maxIndex NOTE(review): not read anywhere in this method
+   * @return {@code null} if no new candidates were added and {@code oneModificationPerIteration}
+   *     is false (the caller should keep the current weights); otherwise a 3-element array:
+   *     [0] {@code FINAL_score} (always 0 as written), [1] the updated {@code earlyStop} count,
+   *     [2] 1 if this should be the last iteration, 0 otherwise
+   * @throws RuntimeException wrapping any {@code IOException} raised while reading or writing
+   *     the temp files
+   */
+  @SuppressWarnings("unchecked")
+  public double[] run_single_iteration(int iteration, int minIts, int maxIts, int prevIts,
+      int earlyStop, int[] maxIndex) {
+    // NOTE(review): FINAL_score is never reassigned below, so retA[0] is always 0.
+    double FINAL_score = 0;
+
+    double[] retA = new double[3];
+    // retA[0]: FINAL_score
+    // retA[1]: earlyStop
+    // retA[2]: should this be the last iteration?
+
+    boolean done = false;
+    retA[2] = 1; // will only be made 0 if we don't break from the following loop
+
+    // save feats and stats for all candidates(old & new)
+    HashMap<String, String>[] feat_hash = new HashMap[numSentences];
+    for (int i = 0; i < numSentences; i++)
+      feat_hash[i] = new HashMap<String, String>();
+
+    HashMap<String, String>[] stats_hash = new HashMap[numSentences];
+    for (int i = 0; i < numSentences; i++)
+      stats_hash[i] = new HashMap<String, String>();
+
+    // Every "break" inside this loop leaves retA[2] == 1, telling the caller to stop iterating.
+    while (!done) { // NOTE: this "loop" will only be carried out once
+      println("--- Starting PRO iteration #" + iteration + " @ " + (new Date()) + " ---", 1);
+
+      // printMemoryUsage();
+
+      /******************************/
+      // CREATE DECODER CONFIG FILE //
+      /******************************/
+
+      createConfigFile(lambda, decoderConfigFileName, decoderConfigFileName + ".PRO.orig");
+      // i.e. use the original config file as a template
+
+      /***************/
+      // RUN DECODER //
+      /***************/
+
+      if (iteration == 1) {
+        println("Decoding using initial weight vector " + lambdaToString(lambda), 1);
+      } else {
+        println("Redecoding using weight vector " + lambdaToString(lambda), 1);
+      }
+
+      // generate the n-best file after decoding
+      String[] decRunResult = run_decoder(iteration); // iteration passed in case fake decoder will
+                                                      // be used
+      // [0] name of file to be processed
+      // [1] indicates how the output file was obtained:
+      // 1: external decoder
+      // 2: fake decoder
+      // 3: internal decoder
+
+      if (!decRunResult[1].equals("2")) {
+        println("...finished decoding @ " + (new Date()), 1);
+      }
+
+      checkFile(decRunResult[0]);
+
+      /************* END OF DECODING **************/
+
+      println("Producing temp files for iteration " + iteration, 3);
+
+      produceTempFiles(decRunResult[0], iteration);
+
+      // save intermediate output files
+      // save joshua.config.pro.it*
+      if (saveInterFiles == 1 || saveInterFiles == 3) { // make copy of intermediate config file
+        if (!copyFile(decoderConfigFileName, decoderConfigFileName + ".PRO.it" + iteration)) {
+          println("Warning: attempt to make copy of decoder config file (to create"
+              + decoderConfigFileName + ".PRO.it" + iteration + ") was unsuccessful!", 1);
+        }
+      }
+
+      // save output.nest.PRO.it*
+      if (saveInterFiles == 2 || saveInterFiles == 3) { // make copy of intermediate decoder output
+                                                        // file...
+
+        if (!decRunResult[1].equals("2")) { // ...but only if no fake decoder
+          if (!decRunResult[0].endsWith(".gz")) {
+            if (!copyFile(decRunResult[0], decRunResult[0] + ".PRO.it" + iteration)) {
+              println("Warning: attempt to make copy of decoder output file (to create"
+                  + decRunResult[0] + ".PRO.it" + iteration + ") was unsuccessful!", 1);
+            }
+          } else {
+            // strip the trailing ".gz" so the copy is named <prefix>.PRO.it<N>.gz
+            String prefix = decRunResult[0].substring(0, decRunResult[0].length() - 3);
+            if (!copyFile(prefix + ".gz", prefix + ".PRO.it" + iteration + ".gz")) {
+              println("Warning: attempt to make copy of decoder output file (to create" + prefix
+                  + ".PRO.it" + iteration + ".gz" + ") was unsuccessful!", 1);
+            }
+          }
+
+          if (compressFiles == 1 && !decRunResult[0].endsWith(".gz")) {
+            gzipFile(decRunResult[0] + ".PRO.it" + iteration);
+          }
+        } // if (!fake)
+      }
+
+      // ------------- end of saving .pro.it* files ---------------
+
+      int[] candCount = new int[numSentences];
+      // NOTE(review): lastUsedIndex is never read in this method, and suffStats_array is only
+      // allocated here and cleared near the end; both look vestigial.
+      int[] lastUsedIndex = new int[numSentences];
+
+      ConcurrentHashMap[] suffStats_array = new ConcurrentHashMap[numSentences];
+      for (int i = 0; i < numSentences; ++i) {
+        candCount[i] = 0;
+        lastUsedIndex[i] = -1;
+        // suffStats_array[i].clear();
+        suffStats_array[i] = new ConcurrentHashMap<>();
+      }
+
+      // initLambda[0] is not used!
+      double[] initialLambda = new double[1 + numParams];
+      for (int i = 1; i <= numParams; ++i)
+        initialLambda[i] = lambda.get(i);
+
+      // the "score" in initialScore refers to that
+      // assigned by the evaluation metric)
+
+      // you may consider all candidates from iter 1, or from iter (iteration-prevIts) to current
+      // iteration
+      int firstIt = Math.max(1, iteration - prevIts);
+      // i.e. only process candidates from the current iteration and candidates
+      // from up to prevIts previous iterations.
+      println("Reading candidate translations from iterations " + firstIt + "-" + iteration, 1);
+      println("(and computing " + metricName
+          + " sufficient statistics for previously unseen candidates)", 1);
+      print(" Progress: ");
+
+      int[] newCandidatesAdded = new int[1 + iteration];
+      for (int it = 1; it <= iteration; ++it)
+        newCandidatesAdded[it] = 0;
+
+      // NOTE(review): the readers/writers opened in this try block are not closed if an
+      // IOException is thrown part-way through; try-with-resources would avoid the leak.
+      try {
+        // read temp files from all past iterations
+        // 3 types of temp files:
+        // 1. output hypo at iter i
+        // 2. feature value of each hypo at iter i
+        // 3. suff stats of each hypo at iter i
+
+        // each inFile corresponds to the output of an iteration
+        // (index 0 is not used; no corresponding index for the current iteration)
+        BufferedReader[] inFile_sents = new BufferedReader[iteration];
+        BufferedReader[] inFile_feats = new BufferedReader[iteration];
+        BufferedReader[] inFile_stats = new BufferedReader[iteration];
+
+        // temp file(array) from previous iterations
+        for (int it = firstIt; it < iteration; ++it) {
+          InputStream inStream_sents, inStream_feats, inStream_stats;
+          if (compressFiles == 0) {
+            inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+            inStream_feats = new FileInputStream(tmpDirPrefix + "temp.feats.it" + it);
+            inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+          } else {
+            inStream_sents = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it"
+                + it + ".gz"));
+            inStream_feats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.feats.it"
+                + it + ".gz"));
+            inStream_stats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it"
+                + it + ".gz"));
+          }
+
+          inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+          inFile_feats[it] = new BufferedReader(new InputStreamReader(inStream_feats, "utf8"));
+          inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+        }
+
+        InputStream inStream_sentsCurrIt, inStream_featsCurrIt, inStream_statsCurrIt;
+        // temp file for current iteration!
+        if (compressFiles == 0) {
+          inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+          inStream_featsCurrIt = new FileInputStream(tmpDirPrefix + "temp.feats.it" + iteration);
+        } else {
+          inStream_sentsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.sents.it" + iteration + ".gz"));
+          inStream_featsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.feats.it" + iteration + ".gz"));
+        }
+
+        BufferedReader inFile_sentsCurrIt = new BufferedReader(new InputStreamReader(
+            inStream_sentsCurrIt, "utf8"));
+        BufferedReader inFile_featsCurrIt = new BufferedReader(new InputStreamReader(
+            inStream_featsCurrIt, "utf8"));
+
+        BufferedReader inFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below
+                                                  // is set to true
+        PrintWriter outFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below is
+                                                // set to false
+
+        // just to check if temp.stat.it.iteration exists
+        // (the uncompressed name is checked first, regardless of compressFiles)
+        boolean statsCurrIt_exists = false;
+
+        if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration)) {
+          inStream_statsCurrIt = new FileInputStream(tmpDirPrefix + "temp.stats.it" + iteration);
+          inFile_statsCurrIt = new BufferedReader(new InputStreamReader(inStream_statsCurrIt,
+              "utf8"));
+          statsCurrIt_exists = true;
+          copyFile(tmpDirPrefix + "temp.stats.it" + iteration, tmpDirPrefix + "temp.stats.it"
+              + iteration + ".copy");
+        } else if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".gz")) {
+          inStream_statsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.stats.it" + iteration + ".gz"));
+          inFile_statsCurrIt = new BufferedReader(new InputStreamReader(inStream_statsCurrIt,
+              "utf8"));
+          statsCurrIt_exists = true;
+          copyFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz", tmpDirPrefix
+              + "temp.stats.it" + iteration + ".copy.gz");
+        } else {
+          outFile_statsCurrIt = new PrintWriter(tmpDirPrefix + "temp.stats.it" + iteration);
+        }
+
+        // output the 4^th temp file: *.temp.stats.merged
+        PrintWriter outFile_statsMerged = new PrintWriter(tmpDirPrefix + "temp.stats.merged");
+        // write sufficient statistics from all the sentences
+        // from the output files into a single file
+        PrintWriter outFile_statsMergedKnown = new PrintWriter(tmpDirPrefix
+            + "temp.stats.mergedKnown");
+        // write sufficient statistics from all the sentences
+        // from the output files into a single file
+
+        // output the 5^th 6^th temp file, but will be deleted at the end of the function
+        FileOutputStream outStream_unknownCands = new FileOutputStream(tmpDirPrefix
+            + "temp.currIt.unknownCands", false);
+        OutputStreamWriter outStreamWriter_unknownCands = new OutputStreamWriter(
+            outStream_unknownCands, "utf8");
+        BufferedWriter outFile_unknownCands = new BufferedWriter(outStreamWriter_unknownCands);
+
+        PrintWriter outFile_unknownIndices = new PrintWriter(tmpDirPrefix
+            + "temp.currIt.unknownIndices");
+
+        String sents_str, feats_str, stats_str;
+
+        // BUG: this assumes a candidate string cannot be produced for two
+        // different source sentences, which is not necessarily true
+        // (It's not actually a bug, but only because existingCandStats gets
+        // cleared before moving to the next source sentence.)
+        // FIX: should be made an array, indexed by i
+        HashMap<String, String> existingCandStats = new HashMap<String, String>();
+        // VERY IMPORTANT:
+        // A CANDIDATE X MAY APPEARED IN ITER 1, ITER 3
+        // BUT IF THE USER SPECIFIED TO CONSIDER ITERATIONS FROM ONLY ITER 2, THEN
+        // X IS NOT A "REPEATED" CANDIDATE IN ITER 3. THEREFORE WE WANT TO KEEP THE
+        // SUFF STATS FOR EACH CANDIDATE(TO SAVE COMPUTATION IN THE FUTURE)
+
+        // Stores precalculated sufficient statistics for candidates, in case
+        // the same candidate is seen again. (SS stored as a String.)
+        // Q: Why do we care? If we see the same candidate again, aren't we going
+        // to ignore it? So, why do we care about the SS of this repeat candidate?
+        // A: A "repeat" candidate may not be a repeat candidate in later
+        // iterations if the user specifies a value for prevMERTIterations
+        // that causes MERT to skip candidates from early iterations.
+
+        String[] featVal_str;
+
+        int totalCandidateCount = 0;
+
+        // new candidate size for each sentence
+        // NOTE(review): written below but never read in this method.
+        int[] sizeUnknown_currIt = new int[numSentences];
+
+        // FIRST PASS: collect known candidates (previous iterations) into temp.stats.mergedKnown
+        // and write previously unseen current-iteration candidates to temp.currIt.unknown*.
+        for (int i = 0; i < numSentences; ++i) {
+          // process candidates from previous iterations
+          // low efficiency? for each iteration, it reads in all previous iteration outputs
+          // therefore a lot of overlapping jobs
+          // this is an easy implementation to deal with the situation in which user only specified
+          // "previt" and hopes to consider only the previous previt
+          // iterations, then for each iteration the existing candidates will be different
+          for (int it = firstIt; it < iteration; ++it) {
+            // Why up to but *excluding* iteration?
+            // Because the last iteration is handled a little differently, since
+            // the SS must be calculated (and the corresponding file created),
+            // which is not true for previous iterations.
+
+            for (int n = 0; n <= sizeOfNBest; ++n) {
+              // note that in all temp files, "||||||" is a separator between 2 n-best lists
+
+              // Why up to and *including* sizeOfNBest?
+              // So that it would read the "||||||" separator even if there is
+              // a complete list of sizeOfNBest candidates.
+
+              // for the nth candidate for the ith sentence, read the sentence, feature values,
+              // and sufficient statistics from the various temp files
+
+              // read one line of temp.sent, temp.feat, temp.stats from iteration it
+              // NOTE(review): a truncated temp file makes readLine() return null, which would
+              // throw a NullPointerException on the equals() call below.
+              sents_str = inFile_sents[it].readLine();
+              feats_str = inFile_feats[it].readLine();
+              stats_str = inFile_stats[it].readLine();
+
+              if (sents_str.equals("||||||")) {
+                n = sizeOfNBest + 1; // move on to the next n-best list
+              } else if (!existingCandStats.containsKey(sents_str)) // if this candidate does not
+                                                                   // exist
+              {
+                outFile_statsMergedKnown.println(stats_str);
+
+                // save feats & stats
+                feat_hash[i].put(sents_str, feats_str);
+                stats_hash[i].put(sents_str, stats_str);
+
+                // extract feature value
+                featVal_str = feats_str.split("\\s+");
+
+                // sparse "name=value" features: grow lambda for feature ids beyond numParams
+                if (feats_str.indexOf('=') != -1) {
+                  for (String featurePair : featVal_str) {
+                    String[] pair = featurePair.split("=");
+                    String name = pair[0];
-                    Double value = Double.parseDouble(pair[1]);
+                    int featId = Vocabulary.id(name);
+                    // need to identify newly fired feats here
+                    // NOTE(review): numParams grows by one per unseen id, which assumes new
+                    // feature ids are assigned consecutively — confirm against Vocabulary.
+                    if (featId > numParams) {
+                      ++numParams;
+                      lambda.add(new Double(0));
+                    }
+                  }
+                }
+                existingCandStats.put(sents_str, stats_str);
+                candCount[i] += 1;
+                newCandidatesAdded[it] += 1;
+              } // if unseen candidate
+            } // for (n)
+          } // for (it)
+
+          outFile_statsMergedKnown.println("||||||");
+
+          // ---------- end of processing previous iterations ----------
+          // ---------- now start processing new candidates ----------
+
+          // now process the candidates of the current iteration
+          // now determine the new candidates of the current iteration
+
+          /*
+           * remember: BufferedReader inFile_sentsCurrIt BufferedReader inFile_featsCurrIt
+           * PrintWriter outFile_statsCurrIt
+           */
+
+          String[] sentsCurrIt_currSrcSent = new String[sizeOfNBest + 1];
+
+          Vector<String> unknownCands_V = new Vector<String>();
+          // which candidates (of the i'th source sentence) have not been seen before
+          // this iteration?
+
+          for (int n = 0; n <= sizeOfNBest; ++n) {
+            // Why up to and *including* sizeOfNBest?
+            // So that it would read the "||||||" separator even if there is
+            // a complete list of sizeOfNBest candidates.
+
+            // for the nth candidate for the ith sentence, read the sentence,
+            // and store it in the sentsCurrIt_currSrcSent array
+
+            sents_str = inFile_sentsCurrIt.readLine(); // read one candidate from the current
+                                                       // iteration
+            sentsCurrIt_currSrcSent[n] = sents_str; // Note: possibly "||||||"
+
+            if (sents_str.equals("||||||")) {
+              n = sizeOfNBest + 1;
+            } else if (!existingCandStats.containsKey(sents_str)) {
+              unknownCands_V.add(sents_str); // NEW CANDIDATE FROM THIS ITERATION
+              writeLine(sents_str, outFile_unknownCands);
+              outFile_unknownIndices.println(i); // INDEX OF THE NEW CANDIDATES
+              newCandidatesAdded[iteration] += 1;
+              existingCandStats.put(sents_str, "U"); // i.e. unknown
+              // we add sents_str to avoid duplicate entries in unknownCands_V
+            }
+          } // for (n)
+
+          // only compute suff stats for new candidates
+          // now unknownCands_V has the candidates for which we need to calculate
+          // sufficient statistics (for the i'th source sentence)
+          int sizeUnknown = unknownCands_V.size();
+          sizeUnknown_currIt[i] = sizeUnknown;
+
+          existingCandStats.clear();
+
+        } // for (i) each sentence
+
+        // ---------- end of merging candidates stats from previous iterations
+        // and finding new candidates ------------
+
+        /*
+         * int[][] newSuffStats = null; if (!statsCurrIt_exists && sizeUnknown > 0) { newSuffStats =
+         * evalMetric.suffStats(unknownCands, indices); }
+         */
+
+        outFile_statsMergedKnown.close();
+        outFile_unknownCands.close();
+        outFile_unknownIndices.close();
+
+        // want to re-open all temp files and start from scratch again?
+        // Rewind the previous-iteration sents/stats readers for the second pass; the feats
+        // readers are not re-read there, so they are left as they are.
+        for (int it = firstIt; it < iteration; ++it) // previous iterations temp files
+        {
+          inFile_sents[it].close();
+          inFile_stats[it].close();
+
+          InputStream inStream_sents, inStream_stats;
+          if (compressFiles == 0) {
+            inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+            inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+          } else {
+            inStream_sents = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it"
+                + it + ".gz"));
+            inStream_stats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it"
+                + it + ".gz"));
+          }
+
+          inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+          inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+        }
+
+        inFile_sentsCurrIt.close();
+        // current iteration temp files
+        if (compressFiles == 0) {
+          inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+        } else {
+          inStream_sentsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.sents.it" + iteration + ".gz"));
+        }
+        inFile_sentsCurrIt = new BufferedReader(new InputStreamReader(inStream_sentsCurrIt, "utf8"));
+
+        // calculate SS for unseen candidates and write them to file
+        FileInputStream inStream_statsCurrIt_unknown = null;
+        BufferedReader inFile_statsCurrIt_unknown = null;
+
+        if (!statsCurrIt_exists && newCandidatesAdded[iteration] > 0) {
+          // create the file...
+          evalMetric.createSuffStatsFile(tmpDirPrefix + "temp.currIt.unknownCands", tmpDirPrefix
+              + "temp.currIt.unknownIndices", tmpDirPrefix + "temp.stats.unknown", sizeOfNBest);
+
+          // ...and open it
+          // NOTE(review): this reader is never closed in this method, although
+          // temp.stats.unknown is deleted further below.
+          inStream_statsCurrIt_unknown = new FileInputStream(tmpDirPrefix + "temp.stats.unknown");
+          inFile_statsCurrIt_unknown = new BufferedReader(new InputStreamReader(
+              inStream_statsCurrIt_unknown, "utf8"));
+        }
+
+        // open mergedKnown file
+        // newly created by the big loop above
+        FileInputStream instream_statsMergedKnown = new FileInputStream(tmpDirPrefix
+            + "temp.stats.mergedKnown");
+        BufferedReader inFile_statsMergedKnown = new BufferedReader(new InputStreamReader(
+            instream_statsMergedKnown, "utf8"));
+
+        // num of features before observing new firing features from this iteration
+        numParamsOld = numParams;
+
+        // SECOND PASS: build temp.stats.merged and temp.stats.it<iteration>, and fill
+        // feat_hash/stats_hash with the current iteration's candidates.
+        for (int i = 0; i < numSentences; ++i) {
+          // reprocess candidates from previous iterations
+          for (int it = firstIt; it < iteration; ++it) {
+            for (int n = 0; n <= sizeOfNBest; ++n) {
+              sents_str = inFile_sents[it].readLine();
+              stats_str = inFile_stats[it].readLine();
+
+              if (sents_str.equals("||||||")) {
+                n = sizeOfNBest + 1;
+              } else if (!existingCandStats.containsKey(sents_str)) {
+                existingCandStats.put(sents_str, stats_str);
+              } // if unseen candidate
+            } // for (n)
+          } // for (it)
+
+          // copy relevant portion from mergedKnown to the merged file
+          String line_mergedKnown = inFile_statsMergedKnown.readLine();
+          while (!line_mergedKnown.equals("||||||")) {
+            outFile_statsMerged.println(line_mergedKnown);
+            line_mergedKnown = inFile_statsMergedKnown.readLine();
+          }
+
+          int[] stats = new int[suffStatsCount];
+
+          for (int n = 0; n <= sizeOfNBest; ++n) {
+            sents_str = inFile_sentsCurrIt.readLine();
+            feats_str = inFile_featsCurrIt.readLine();
+
+            if (sents_str.equals("||||||")) {
+              n = sizeOfNBest + 1;
+            } else if (!existingCandStats.containsKey(sents_str)) {
+
+              if (!statsCurrIt_exists) {
+                stats_str = inFile_statsCurrIt_unknown.readLine();
+
+                String[] temp_stats = stats_str.split("\\s+");
+                for (int s = 0; s < suffStatsCount; ++s) {
+                  stats[s] = Integer.parseInt(temp_stats[s]);
+                }
+
+                outFile_statsCurrIt.println(stats_str);
+              } else {
+                stats_str = inFile_statsCurrIt.readLine();
+
+                String[] temp_stats = stats_str.split("\\s+");
+                for (int s = 0; s < suffStatsCount; ++s) {
+                  stats[s] = Integer.parseInt(temp_stats[s]);
+                }
+              }
+
+              outFile_statsMerged.println(stats_str);
+
+              // save feats & stats
+              // System.out.println(sents_str+" "+feats_str);
+
+              feat_hash[i].put(sents_str, feats_str);
+              stats_hash[i].put(sents_str, stats_str);
+
+              featVal_str = feats_str.split("\\s+");
+
+              if (feats_str.indexOf('=') != -1) {
+                for (String featurePair : featVal_str) {
+                  String[] pair = featurePair.split("=");
+                  String name = pair[0];
+                  int featId = Vocabulary.id(name);
+                  // need to identify newly fired feats here
+                  if (featId > numParams) {
+                    ++numParams;
+                    lambda.add(new Double(0));
+                  }
+                }
+              }
+              existingCandStats.put(sents_str, stats_str);
+              candCount[i] += 1;
+
+              // newCandidatesAdded[iteration] += 1;
+              // moved to code above detecting new candidates
+            } else {
+              // repeated candidate: keep the current-iteration stats file aligned line by line
+              if (statsCurrIt_exists)
+                inFile_statsCurrIt.readLine();
+              else {
+                // write SS to outFile_statsCurrIt
+                stats_str = existingCandStats.get(sents_str);
+                outFile_statsCurrIt.println(stats_str);
+              }
+            }
+
+          } // for (n)
+
+          // consume (or emit) the "||||||" separator of the current-iteration stats file
+
+          if (statsCurrIt_exists)
+            inFile_statsCurrIt.readLine();
+          else
+            outFile_statsCurrIt.println("||||||");
+
+          existingCandStats.clear();
+          totalCandidateCount += candCount[i];
+
+          // output sentence progress
+          if ((i + 1) % 500 == 0) {
+            print((i + 1) + "\n" + " ", 1);
+          } else if ((i + 1) % 100 == 0) {
+            print("+", 1);
+          } else if ((i + 1) % 25 == 0) {
+            print(".", 1);
+          }
+
+        } // for (i)
+
+        inFile_statsMergedKnown.close();
+        outFile_statsMerged.close();
+
+        // for testing
+        /*
+         * int total_sent = 0; for( int i=0; i<numSentences; i++ ) {
+         * System.out.println(feat_hash[i].size()+" "+candCount[i]); total_sent +=
+         * feat_hash[i].size(); feat_hash[i].clear(); }
+         * System.out.println("----------------total sent: "+total_sent); total_sent = 0; for( int
+         * i=0; i<numSentences; i++ ) { System.out.println(stats_hash[i].size()+" "+candCount[i]);
+         * total_sent += stats_hash[i].size(); stats_hash[i].clear(); }
+         * System.out.println("*****************total sent: "+total_sent);
+         */
+
+        println("", 1); // finish progress line
+
+        for (int it = firstIt; it < iteration; ++it) {
+          inFile_sents[it].close();
+          inFile_feats[it].close();
+          inFile_stats[it].close();
+        }
+
+        inFile_sentsCurrIt.close();
+        inFile_featsCurrIt.close();
+        if (statsCurrIt_exists)
+          inFile_statsCurrIt.close();
+        else
+          outFile_statsCurrIt.close();
+
+        if (compressFiles == 1 && !statsCurrIt_exists) {
+          gzipFile(tmpDirPrefix + "temp.stats.it" + iteration);
+        }
+
+        // clear temp files
+        deleteFile(tmpDirPrefix + "temp.currIt.unknownCands");
+        deleteFile(tmpDirPrefix + "temp.currIt.unknownIndices");
+        deleteFile(tmpDirPrefix + "temp.stats.unknown");
+        deleteFile(tmpDirPrefix + "temp.stats.mergedKnown");
+
+        // cleanupMemory();
+
+        println("Processed " + totalCandidateCount + " distinct candidates " + "(about "
+            + totalCandidateCount / numSentences + " per sentence):", 1);
+        for (int it = firstIt; it <= iteration; ++it) {
+          println("newCandidatesAdded[it=" + it + "] = " + newCandidatesAdded[it] + " (about "
+              + newCandidatesAdded[it] / numSentences + " per sentence)", 1);
+        }
+
+        println("", 1);
+
+        println("Number of features observed so far: " + numParams);
+        println("", 1);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+
+      // n-best list converges
+      if (newCandidatesAdded[iteration] == 0) {
+        if (!oneModificationPerIteration) {
+          println("No new candidates added in this iteration; exiting PRO.", 1);
+          println("", 1);
+          println("--- PRO iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+          println("", 1);
+          deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+          if (returnBest) {
+            // note that bestLambda.size() <= lambda.size()
+            for (int p = 1; p < bestLambda.size(); ++p)
+              lambda.set(p, bestLambda.get(p));
+            // and set the rest of lambda to be 0
+            for (int p = 0; p < lambda.size() - bestLambda.size(); ++p)
+              lambda.set(p + bestLambda.size(), new Double(0));
+          }
+
+          return null; // this means that the old values should be kept by the caller
+        } else {
+          println("Note: No new candidates added in this iteration.", 1);
+        }
+      }
+
+      /************* start optimization **************/
+
+      /*
+       * for( int v=1; v<initialLambda[1].length; v++ ) System.out.print(initialLambda[1][v]+" ");
+       * System.exit(0);
+       */
+
+      Vector<String> output = new Vector<String>();
+
+      // note: initialLambda[] has length = numParamsOld
+      // augmented with new feature weights, initial values are 0
+      double[] initialLambdaNew = new double[1 + numParams];
+      System.arraycopy(initialLambda, 1, initialLambdaNew, 1, numParamsOld);
+
+      // finalLambda[] has length = numParams (considering new features)
+      double[] finalLambda = new double[1 + numParams];
+
+      Optimizer opt = new Optimizer(seed + iteration, isOptimizable, output, initialLambdaNew,
+          feat_hash, stats_hash, evalMetric, Tau, Xi, metricDiff, normalizationOptions,
+          classifierAlg, classifierParams);
+      finalLambda = opt.run_Optimizer();
+
+      // remember the best weights seen so far (direction depends on whether the metric is
+      // to be minimized or maximized)
+      if (returnBest) {
+        double metricScore = opt.getMetricScore();
+        if (!evalMetric.getToBeMinimized()) {
+          if (metricScore > prevMetricScore) {
+            prevMetricScore = metricScore;
+            for (int p = 1; p < bestLambda.size(); ++p)
+              bestLambda.set(p, finalLambda[p]);
+            if (1 + numParams > bestLambda.size()) {
+              for (int p = bestLambda.size(); p <= numParams; ++p)
+                bestLambda.add(p, finalLambda[p]);
+            }
+          }
+        } else {
+          if (metricScore < prevMetricScore) {
+            prevMetricScore = metricScore;
+            for (int p = 1; p < bestLambda.size(); ++p)
+              bestLambda.set(p, finalLambda[p]);
+            if (1 + numParams > bestLambda.size()) {
+              for (int p = bestLambda.size(); p <= numParams; ++p)
+                bestLambda.add(p, finalLambda[p]);
+            }
+          }
+        }
+      }
+
+      // System.out.println(finalLambda.length);
+      // for( int i=0; i<finalLambda.length-1; i++ )
+      // System.out.print(finalLambda[i+1]+" ");
+      // System.out.println();
+
+      /************* end optimization **************/
+
+      for (int i = 0; i < output.size(); i++)
+        println(output.get(i));
+
+      // check if any parameter has been updated
+      boolean anyParamChanged = false;
+      boolean anyParamChangedSignificantly = false;
+
+      for (int c = 1; c <= numParams; ++c) {
+        if (finalLambda[c] != lambda.get(c)) {
+          anyParamChanged = true;
+        }
+        if (Math.abs(finalLambda[c] - lambda.get(c)) > stopSigValue) {
+          anyParamChangedSignificantly = true;
+        }
+      }
+
+      // System.arraycopy(finalLambda,1,lambda,1,numParams);
+
+      println("--- PRO iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+      println("", 1);
+
+      if (!anyParamChanged) {
+        println("No parameter value changed in this iteration; exiting PRO.", 1);
+        println("", 1);
+        break; // exit for (iteration) loop preemptively
+      }
+
+      // was an early stopping criterion satisfied?
+      boolean critSatisfied = false;
+      if (!anyParamChangedSignificantly && stopSigValue >= 0) {
+        println("Note: No parameter value changed significantly " + "(i.e. by more than "
+            + stopSigValue + ") in this iteration.", 1);
+        critSatisfied = true;
+      }
+
+      if (critSatisfied) {
+        ++earlyStop;
+        println("", 1);
+      } else {
+        earlyStop = 0;
+      }
+
+      // if min number of iterations executed, investigate if early exit should happen
+      if (iteration >= minIts && earlyStop >= stopMinIts) {
+        println("Some early stopping criteria has been observed " + "in " + stopMinIts
+            + " consecutive iterations; exiting PRO.", 1);
+        println("", 1);
+
+        if (returnBest) {
+          // note that numParams >= bestLamba.size()-1 here!
+          for (int f = 1; f <= bestLambda.size() - 1; ++f)
+            lambda.set(f, bestLambda.get(f));
+        } else {
+          for (int f = 1; f <= numParams; ++f)
+            lambda.set(f, finalLambda[f]);
+        }
+
+        break; // exit for (iteration) loop preemptively
+      }
+
+      // if max number of iterations executed, exit
+      if (iteration >= maxIts) {
+        println("Maximum number of PRO iterations reached; exiting PRO.", 1);
+        println("", 1);
+
+        if (returnBest) {
+          // note that numParams >= bestLamba.size()-1 here!
+          for (int f = 1; f <= bestLambda.size() - 1; ++f)
+            lambda.set(f, bestLambda.get(f));
+        } else {
+          for (int f = 1; f <= numParams; ++f)
+            lambda.set(f, finalLambda[f]);
+        }
+
+        break; // exit for (iteration) loop
+      }
+
+      // use the new wt vector to decode the next iteration
+      // (interpolation with previous wt vector)
+      for (int i = 1; i <= numParams; i++)
+        lambda.set(i, interCoef * finalLambda[i] + (1 - interCoef) * lambda.get(i).doubleValue());
+
+      println("Next iteration will decode with lambda: " + lambdaToString(lambda), 1);
+      println("", 1);
+
+      // printMemoryUsage();
+      for (int i = 0; i < numSentences; ++i) {
+        suffStats_array[i].clear();
+      }
+      // cleanupMemory();
+      // println("",2);
+
+      retA[2] = 0; // i.e. this should NOT be the last iteration
+      done = true;
+
+    } // while (!done) // NOTE: this "loop" will only be carried out once
+
+    // delete .temp.stats.merged file, since it is not needed in the next
+    // iteration (it will be recreated from scratch)
+    deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+    retA[0] = FINAL_score;
+    retA[1] = earlyStop;
+    return retA;
+
+  } // run_single_iteration
+
+ private String lambdaToString(ArrayList<Double> lambdaA) {
+ String retStr = "{";
+ int featToPrint = numParams > 15 ? 15 : numParams;
+ // print at most the first 15 features
+
+ retStr += "(listing the first " + featToPrint + " lambdas)";
+ for (int c = 1; c <= featToPrint - 1; ++c) {
+ retStr += "" + String.format("%.4f", lambdaA.get(c).doubleValue()) + ", ";
+ }
+ retStr += "" + String.format("%.4f", lambdaA.get(numParams).doubleValue()) + "}";
+
+ return retStr;
+ }
+
+ /**
+ * Obtains the n-best output for the given iteration. If a fake-file template is configured and
+ * fakeFileNamePrefix + iteration + fakeFileNameSuffix exists, that file is used as is;
+ * otherwise the external decoder command is run.
+ *
+ * @param iteration current iteration number (used in the fake file name and, if
+ * passIterationToDecoder is set, passed as the decoder's only argument)
+ * @return two-element array: [0] name of the file to be processed, [1] how it was obtained
+ * ("1" = external decoder, "2" = fake decoder)
+ * @throws RuntimeException if the decoder cannot be started, the wait is interrupted, or the
+ * decoder exits with a status other than validDecoderExitValue
+ */
+ private String[] run_decoder(int iteration) {
+ String[] retSA = new String[2];
+
+ String fakeFileName = (fakeFileNameTemplate == null) ? null
+ : fakeFileNamePrefix + iteration + fakeFileNameSuffix;
+
+ // use fake decoder
+ if (fakeFileName != null && fileExists(fakeFileName)) {
+ println("Not running decoder; using " + fakeFileName + " instead.", 1);
+ retSA[0] = fakeFileName;
+ retSA[1] = "2";
+ return retSA;
+ }
+
+ println("Running external decoder...", 1);
+
+ try {
+ ArrayList<String> cmd = new ArrayList<String>();
+ cmd.add(decoderCommandFileName);
+
+ if (passIterationToDecoder)
+ cmd.add(Integer.toString(iteration));
+
+ ProcessBuilder pb = new ProcessBuilder(cmd);
+ // this merges the error and output streams of the subprocess
+ pb.redirectErrorStream(true);
+ Process p = pb.start();
+
+ // capture the sub-command's output
+ new StreamGobbler(p.getInputStream(), decVerbosity).start();
+
+ int decStatus = p.waitFor();
+ if (decStatus != validDecoderExitValue) {
+ throw new RuntimeException("Call to decoder returned " + decStatus + "; was expecting "
+ + validDecoderExitValue + ".");
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ // restore the interrupt flag so code further up the stack can still see the interruption
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+
+ retSA[0] = decoderOutFileName;
+ retSA[1] = "1";
+ return retSA;
+ }
+
+ /**
+ * Splits an n-best file into two per-iteration temp files:
+ * tmpDirPrefix + "temp.sents.it" + iteration, holding the normalized candidate translations
+ * (written as UTF-8), and tmpDirPrefix + "temp.feats.it" + iteration, holding the feature
+ * strings. In both files a "||||||" line ends the candidate list of each source sentence.
+ * Both files are gzipped afterwards when compressFiles == 1.
+ *
+ * @param nbestFileName n-best file, read as UTF-8 (gunzipped on the fly if it ends in ".gz");
+ * each line has the form "i ||| candidate ||| features [||| ...]"
+ * @param iteration current iteration number, used in the temp file names
+ */
+ private void produceTempFiles(String nbestFileName, int iteration) {
+ String sentsFileName = tmpDirPrefix + "temp.sents.it" + iteration;
+ String featsFileName = tmpDirPrefix + "temp.feats.it" + iteration;
+
+ try {
+ // try-with-resources closes every stream even if parsing fails midway; the output files
+ // are flushed and closed before they are gzipped below
+ try (BufferedWriter outFile_sents = new BufferedWriter(
+ new OutputStreamWriter(new FileOutputStream(sentsFileName, false), "utf8"));
+ PrintWriter outFile_feats = new PrintWriter(featsFileName);
+ InputStream rawIn_nbest = new FileInputStream(nbestFileName);
+ InputStream inStream_nbest = nbestFileName.endsWith(".gz")
+ ? new GZIPInputStream(rawIn_nbest) : rawIn_nbest;
+ BufferedReader inFile_nbest = new BufferedReader(
+ new InputStreamReader(inStream_nbest, "utf8"))) {
+
+ int i = 0; // index of the source sentence currently being read
+ int n = 0; // number of candidates read so far for sentence i
+ String line;
+
+ while ((line = inFile_nbest.readLine()) != null) {
+ /*
+ * line format:
+ *
+ * i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val
+ * .*
+ */
+
+ // in a well formed file, we'd find the nth candidate for the ith sentence
+ int read_i = Integer.parseInt((line.substring(0, line.indexOf("|||"))).trim());
+
+ // a new sentence index ends the previous sentence's (short) candidate list
+ if (read_i != i) {
+ writeLine("||||||", outFile_sents);
+ outFile_feats.println("||||||");
+ n = 0;
+ ++i;
+ }
+
+ line = (line.substring(line.indexOf("|||") + 3)).trim(); // get rid of initial text
+
+ String candidate_str = (line.substring(0, line.indexOf("|||"))).trim();
+ // get rid of candidate string
+ String feats_str = (line.substring(line.indexOf("|||") + 3)).trim();
+
+ // drop any further "|||"-separated fields after the features
+ int junk_i = feats_str.indexOf("|||");
+ if (junk_i >= 0) {
+ feats_str = (feats_str.substring(0, junk_i)).trim();
+ }
+
+ writeLine(normalize(candidate_str, textNormMethod), outFile_sents);
+ outFile_feats.println(feats_str);
+
+ // a full n-best list also ends the sentence
+ ++n;
+ if (n == sizeOfNBest) {
+ writeLine("||||||", outFile_sents);
+ outFile_feats.println("||||||");
+ n = 0;
+ ++i;
+ }
+ }
+
+ if (i != numSentences) { // last sentence had too few candidates
+ writeLine("||||||", outFile_sents);
+ outFile_feats.println("||||||");
+ }
+ }
+
+ if (compressFiles == 1) {
+ gzipFile(sentsFileName);
+ gzipFile(featsFileName);
+ }
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ }
+
+ /**
+ * Writes cfgFileName as a copy of templateFileName in which every line starting with a
+ * feature name followed by a space is replaced by "name value" taken from params. Weights
+ * whose absolute value is at most 1e-20 are omitted. Weights of the features numbered after
+ * those found in the template are appended at the end.
+ *
+ * @param params weight vector, read at indices 1..numParams
+ * @param cfgFileName config file to create (overwritten if present)
+ * @param templateFileName existing config file used as the template
+ */
+ private void createConfigFile(ArrayList<Double> params, String cfgFileName,
+ String templateFileName) {
+ // i.e. create cfgFileName, which is similar to templateFileName, but with
+ // params[] as parameter values.
+ // try-with-resources closes both files even if an I/O error occurs midway.
+ try (BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
+ PrintWriter outFile = new PrintWriter(cfgFileName)) {
+
- BufferedReader inFeatDefFile = null;
- PrintWriter outFeatDefFile = null;
+ int origFeatNum = 0; // feat num in the template file
+
+ String line;
+ while ((line = inFile.readLine()) != null) {
+ int c_match = -1;
+ for (int c = 1; c <= numParams; ++c) {
+ if (line.startsWith(Vocabulary.word(c) + " ")) {
+ c_match = c;
+ ++origFeatNum;
+ break;
+ }
+ }
+
+ if (c_match == -1) {
+ outFile.println(line); // not a weight line: copy unchanged
+ } else if (Math.abs(params.get(c_match).doubleValue()) > 1e-20) {
+ // weight line: rewrite with the new value; (near-)zero weights are dropped
+ outFile.println(Vocabulary.word(c_match) + " " + params.get(c_match));
+ }
+ }
+
+ // now append weights of new features
+ // NOTE(review): assumes the template's features are exactly ids 1..origFeatNum
+ for (int c = origFeatNum + 1; c <= numParams; ++c) {
+ if (Math.abs(params.get(c).doubleValue()) > 1e-20)
+ outFile.println(Vocabulary.word(c) + " " + params.get(c));
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Reads paramsFileName and initializes the per-parameter arrays, then normalizationOptions[].
+ *
+ * For each parameter c in 1..numParams the file is scanned as:
+ * - name tokens up to "|||";
+ * - the default value (stored in lambda[c] and defaultLambda[c]);
+ * - "Opt" or "Fix";
+ * - four more values. For "Opt" parameters the last two are minRandValue[c] and
+ * maxRandValue[c]; the first two are unused and kept only for ZMERT params-file compatibility.
+ * The first non-empty line after the parameters ("normalization = ...") sets
+ * normalizationOptions[].
+ *
+ * @throws RuntimeException if the file is missing or contains an invalid or unrecognized value
+ */
+ private void processParamFile() {
+ // process parameter file; try-with-resources closes the Scanner on every exit path,
+ // including the RuntimeExceptions thrown below for invalid values
+ try (Scanner inFile_init = new Scanner(new FileReader(paramsFileName))) {
+
+ String dummy = "";
+
+ // initialize lambda[] and other related arrays
+ for (int c = 1; c <= numParams; ++c) {
+ // skip parameter name
+ while (!dummy.equals("|||")) {
+ dummy = inFile_init.next();
+ }
+
+ // read default value
+ lambda.set(c, inFile_init.nextDouble());
+ defaultLambda[c] = lambda.get(c).doubleValue();
+
+ // read isOptimizable
+ dummy = inFile_init.next();
+ if (dummy.equals("Opt")) {
+ isOptimizable[c] = true;
+ } else if (dummy.equals("Fix")) {
+ isOptimizable[c] = false;
+ } else {
+ throw new RuntimeException("Unknown isOptimizable string " + dummy + " (must be either Opt or Fix)");
+ }
+
+ if (!isOptimizable[c]) { // skip the next four values
+ for (int k = 0; k < 4; ++k) {
+ dummy = inFile_init.next();
+ }
+ } else {
+ // the next two values are not used, only to be consistent with ZMERT's params file format
+ dummy = inFile_init.next();
+ dummy = inFile_init.next();
+ // set minRandValue[c] and maxRandValue[c] (range for random values)
+ dummy = inFile_init.next();
+ if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+ throw new RuntimeException("minRandValue[" + c + "] cannot be -Inf or +Inf!");
+ } else {
+ minRandValue[c] = Double.parseDouble(dummy);
+ }
+
+ dummy = inFile_init.next();
+ if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+ throw new RuntimeException("maxRandValue[" + c + "] cannot be -Inf or +Inf!");
+ } else {
+ maxRandValue[c] = Double.parseDouble(dummy);
+ }
+
+ // check for illogical values
+ if (minRandValue[c] > maxRandValue[c]) {
+ throw new RuntimeException("minRandValue[" + c + "]=" + minRandValue[c] + " > " + maxRandValue[c]
+ + "=maxRandValue[" + c + "]!");
+ }
+
+ // check for odd values
+ if (minRandValue[c] == maxRandValue[c]) {
+ println("Warning: lambda[" + c + "] has " + "minRandValue = maxRandValue = "
+ + minRandValue[c] + ".", 1);
+ }
+ } // if (!isOptimizable[c])
+ }
+
+ // set normalizationOptions[] from the first non-empty line after the parameters
+ String origLine = "";
+ while (origLine.isEmpty()) {
+ origLine = inFile_init.nextLine();
+ }
+
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ // normalization = none
+ // normalization = absval 1 lm
+ // normalization = maxabsval 1
+ // normalization = minabsval 1
+ // normalization = LNorm 2 1
+
+ dummy = (origLine.substring(origLine.indexOf("=") + 1)).trim();
+ String[] dummyA = dummy.split("\\s+");
+
+ if (dummyA[0].equals("none")) {
+ normalizationOptions[0] = 0;
+ } else if (dummyA[0].equals("absval")) {
+ normalizationOptions[0] = 1;
+ normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+ String pName = dummyA[2];
+ for (int i = 3; i < dummyA.length; ++i) { // in case parameter name has multiple words
+ pName = pName + " " + dummyA[i];
+ }
+ normalizationOptions[2] = Vocabulary.id(pName);
+
+ if (normalizationOptions[1] <= 0) {
+ throw new RuntimeException("Value for the absval normalization method must be positive.");
+ }
+ if (normalizationOptions[2] == 0) {
+ // report the name that failed to resolve (the stored id is always 0 here)
+ throw new RuntimeException("Unrecognized feature name " + pName
+ + " for absval normalization method.");
+ }
+ } else if (dummyA[0].equals("maxabsval")) {
+ normalizationOptions[0] = 2;
+ normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+ if (normalizationOptions[1] <= 0) {
+ throw new RuntimeException("Value for the maxabsval normalization method must be positive.");
+ }
+ } else if (dummyA[0].equals("minabsval")) {
+ normalizationOptions[0] = 3;
+ normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+ if (normalizationOptions[1] <= 0) {
+ throw new RuntimeException("Value for the minabsval normalization method must be positive.");
+ }
+ } else if (dummyA[0].equals("LNorm")) {
+ normalizationOptions[0] = 4;
+ normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+ normalizationOptions[2] = Double.parseDouble(dummyA[2]);
+ if (normalizationOptions[1] <= 0 || normalizationOptions[2] <= 0) {
+ throw new RuntimeException("Both values for the LNorm normalization method must be positive.");
+ }
+ } else {
+ throw new RuntimeException("Unrecognized normalization method " + dummyA[0] + "; "
+ + "must be one of none, absval, maxabsval, minabsval, and LNorm.");
+ } // if (dummyA[0])
+
+ } catch (FileNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ } // processParamFile()
+
+ /**
+ * Sets numDocuments and docOfSentence[], the 0-based document index of each sentence.
+ * Without a doc info file, all sentences belong to a single document 0.
+ *
+ * 4 possible formats of docInfoFileName:
+ * 1) List of numbers, one per document, indicating # sentences in each document.
+ * 2) List of "docName size" pairs, one per document, indicating name of document and #
+ * sentences.
+ * 3) List of docName's, one per sentence, indicating which document each sentence belongs to.
+ * 4) List of docName_number's, one per sentence, indicating which document each sentence
+ * belongs to, and its order in that document. (can also use '-' instead of '_')
+ *
+ * Formats 1/2 are recognized by having fewer non-empty lines than sentences; format 1 if
+ * the first line has no space. Formats 3/4 have exactly numSentences non-empty lines; format 3
+ * if any line repeats. Documents are numbered in the order they are first seen.
+ */
+ private void processDocInfo() {
+ docOfSentence = new int[numSentences]; // zero-filled: every sentence starts in document 0
+
+ if (docInfoFileName == null) {
+ numDocuments = 1;
+ return;
+ }
+
+ try {
+
- int docInfoSize = countNonEmptyLines(docInfoFileName);
++ int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
+
+ // The readers below decode the file as UTF-8, presumably matching the line count above
+ // (rather than using the platform default charset). Each is closed by try-with-resources.
+ if (docInfoSize < numSentences) { // format #1 or #2
+ numDocuments = docInfoSize;
+ int i = 0;
+
+ try (BufferedReader inFile = new BufferedReader(
+ new InputStreamReader(new FileInputStream(docInfoFileName), "utf8"))) {
+ String line = inFile.readLine();
+ boolean format1 = (!(line.contains(" ")));
+
+ for (int doc = 0; doc < numDocuments; ++doc) {
+ if (doc != 0)
+ line = inFile.readLine();
+
+ int docSize = format1 ? Integer.parseInt(line) : Integer.parseInt(line.split("\\s+")[1]);
+
+ for (int i2 = 1; i2 <= docSize; ++i2) {
+ docOfSentence[i] = doc;
+ ++i;
+ }
+ }
+ }
+ // now i == numSentences
+
+ } else if (docInfoSize == numSentences) { // format #3 or #4
+
+ // first pass: format #3 if any line is repeated
+ boolean format3 = false;
+ HashSet<String> seenStrings = new HashSet<String>();
+ try (BufferedReader inFile = new BufferedReader(
+ new InputStreamReader(new FileInputStream(docInfoFileName), "utf8"))) {
+ for (int i = 0; i < numSentences; ++i) {
+ if (!seenStrings.add(inFile.readLine()))
+ format3 = true;
+ }
+ }
+
+ // second pass: maps a document name to the order (0-indexed) in which it was seen
+ HashMap<String, Integer> docOrder = new HashMap<String, Integer>();
+ try (BufferedReader inFile = new BufferedReader(
+ new InputStreamReader(new FileInputStream(docInfoFileName), "utf8"))) {
+ for (int i = 0; i < numSentences; ++i) {
+ String line = inFile.readLine();
+
+ String docName;
+ if (format3) {
+ docName = line;
+ } else {
+ int sep_i = Math.max(line.lastIndexOf('_'), line.lastIndexOf('-'));
+ docName = line.substring(0, sep_i);
+ }
+
+ Integer docOrder_i = docOrder.get(docName);
+ if (docOrder_i == null) {
+ docOrder_i = docOrder.size();
+ docOrder.put(docName, docOrder_i);
+ }
+ docOfSentence[i] = docOrder_i;
+ }
+ }
+
+ numDocuments = docOrder.size();
+
+ } else { // badly formatted: more non-empty lines than sentences
+ println("Warning: doc info file " + docInfoFileName + " has " + docInfoSize
+ + " non-empty lines but there are only " + numSentences
+ + " sentences; treating all sentences as a single document.", 1);
+ numDocuments = 1; // consistent with the all-zero docOfSentence[]
+ }
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ }
+
+ /**
+ * Copies origFileName byte-for-byte to newFileName, overwriting newFileName if it exists.
+ *
+ * @return true on success; false if an I/O error occurred (the error is logged, not thrown)
+ */
+ private boolean copyFile(String origFileName, String newFileName) {
+ // try-with-resources closes both streams even when the copy fails midway
+ try (InputStream in = new FileInputStream(new File(origFileName));
+ OutputStream out = new FileOutputStream(new File(newFileName))) {
+ byte[] buffer = new byte[8192];
+ int len;
+ while ((len = in.read(buffer)) != -1) {
+ out.write(buffer, 0, len);
+ }
+ return true;
+ } catch (IOException e) {
+ LOG.error(e.getMessage(), e);
+ return false;
+ }
+ }
+
+ /**
+ * Renames origFileName to newFileName, first deleting any existing newFileName.
+ * A missing source file or a failed rename is reported as a warning at verbosity level 1
+ * instead of being thrown.
+ */
+ private void renameFile(String origFileName, String newFileName) {
+ if (!fileExists(origFileName)) {
+ println("Warning: file " + origFileName + " does not exist! (in PROCore.renameFile)", 1);
+ return;
+ }
+
+ deleteFile(newFileName);
+ boolean renamed = new File(origFileName).renameTo(new File(newFileName));
+ if (!renamed) {
+ println("Warning: attempt to rename " + origFileName + " to " + newFileName
+ + " was unsuccessful!", 1);
+ }
+ }
+
+ /**
+ * Deletes fileName if it exists; a failed deletion is reported as a warning at verbosity 1.
+ */
+ private void deleteFile(String fileName) {
+ if (!fileExists(fileName)) {
+ return; // nothing to delete
+ }
+ boolean deleted = new File(fileName).delete();
+ if (!deleted) {
+ println("Warning: attempt to delete " + fileName + " was unsuccessful!", 1);
+ }
+ }
+
+ /**
+ * Writes line plus a line separator (BufferedWriter.newLine) to writer, then flushes it.
+ */
+ private void writeLine(String line, BufferedWriter writer) throws IOException {
+ writer.write(line);
+ writer.newLine();
+ writer.flush();
+ }
+
+ // need to re-write to handle different forms of lambda
+ public void finish() {
+ if (myDecoder != null) {
+ myDecoder.cleanUp();
+ }
+
+ // create config file with final values
+ createConfigFile(lambda, decoderConfigFileName + ".PRO.final", decoderConfigFileName
+ + ".PRO.orig");
+
+ // delete current decoder config file and decoder output
+ deleteFile(decoderConfigFileName);
+ deleteFile(decoderOutFileName);
+
+ // restore original name for config file (name was changed
+ // in initialize() so it doesn't get overwritten)
+ renameFile(decoderConfigFileName + ".PRO.orig", decoderConfigFileName);
+
+ if (finalLambdaFileName != null) {
+ try {
+ PrintWriter outFile_lambdas = new PrintWriter(finalLambdaFileName);
+ for (int c = 1; c <= numParams; ++c) {
+ outFile_lambdas.println(Vocabulary.word(c) + " ||| " + lambda.get(c).doubleValue());
+ }
+ outFile_lambdas.close();
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+ /**
+ * Reads a PRO configuration file and flattens it into a command-line style argument array.
+ * Text from '#' to the end of a line is ignored, and blank lines are skipped. Every other line
+ * must have the form "-option value"; only -m and -docSet may take more than one value.
+ * Single quotes group whitespace-separated words into one argument.
+ *
+ * @param fileName configuration file to read (checked with checkFile first)
+ * @return the arguments in file order, e.g. {"-dir", "d", "-m", "BLEU", "4", "closest"}
+ * @throws RuntimeException if the file is missing, unreadable, or has a malformed line
+ */
+ private String[] cfgFileToArgsArray(String fileName) {
+ checkFile(fileName);
+
+ Vector<String> argsVector = new Vector<String>();
+
- BufferedReader inFile = null;
- try {
- inFile = new BufferedReader(new FileReader(fileName));
++ try (BufferedReader inFile = new BufferedReader(new FileReader(fileName))) {
+ String line;
+ while ((line = inFile.readLine()) != null) {
+ String origLine = line; // for error reporting purposes
+
+ int commentStart = line.indexOf('#');
+ if (commentStart != -1) { // discard comment
+ line = line.substring(0, commentStart);
+ }
+ line = line.trim();
+
+ // skip blank and comment-only lines, including indented comments, which would
+ // otherwise be rejected as malformed
+ if (line.isEmpty()) {
+ continue;
+ }
+
+ // now line should look like "-xxx XXX"
+
+ // CMU MODIFICATION(FROM METEOR FOR ZMERT)
+ // Parse args
+ ArrayList<String> argList = new ArrayList<String>();
+ StringBuilder arg = new StringBuilder();
+ boolean quoted = false;
+ for (int i = 0; i < line.length(); i++) {
+ if (Character.isWhitespace(line.charAt(i))) {
+ if (quoted)
+ arg.append(line.charAt(i));
+ else if (arg.length() > 0) {
+ argList.add(arg.toString());
+ arg = new StringBuilder();
+ }
+ } else if (line.charAt(i) == '\'') {
+ if (quoted) {
+ argList.add(arg.toString());
+ arg = new StringBuilder();
+ }
+ quoted = !quoted;
+ } else
+ arg.append(line.charAt(i));
+ }
+ if (arg.length() > 0)
+ argList.add(arg.toString());
+ // Create paramA
+ String[] paramA = argList.toArray(new String[argList.size()]);
+ // END CMU MODIFICATION
+
+ // startsWith (rather than charAt(0)) also copes with an empty quoted first token
+ if (paramA.length == 2 && paramA[0].startsWith("-")) {
+ argsVector.add(paramA[0]);
+ argsVector.add(paramA[1]);
+ } else if (paramA.length > 2 && (paramA[0].equals("-m") || paramA[0].equals("-docSet"))) {
+ // -m (metricName), -docSet are allowed to have extra options
+ for (String opt : paramA) {
+ argsVector.add(opt);
+ }
+ } else {
+ throw new RuntimeException("Malformed line in config file:" + origLine);
+ }
+ }
-
- inFile.close();
+ } catch (FileNotFoundException e) {
+ println("PRO configuration file " + fileName + " was not found!");
+ throw new RuntimeException(e);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ return argsVector.toArray(new String[argsVector.size()]);
+ }
+
+ private void processArgsArray(String[] args) {
+ processArgsArray(args, true);
+ }
+
+ private void processArgsArray(String[] args, boolean firstTime) {
+ /* set default values */
+ // Relevant files
+ dirPrefix = null;
+ sourceFileName = null;
+ refFileName = "reference.txt";
+ refsPerSen = 1;
+ textNormMethod = 1;
+ paramsFileName = "params.txt";
+ docInfoFileName = null;
+ finalLambdaFileName = null;
+ // MERT specs
+ metricName = "BLEU";
+ metricName_display = metricName;
+ metricOptions = new String[2];
+ metricOptions[0] = "4";
+ metricOptions[1] = "closest";
+ docSubsetInfo = new int[7];
+ docSubsetInfo[0] = 0;
+ maxMERTIterations = 20;
+ prevMERTIterations = 20;
+ minMERTIterations = 5;
+ stopMinIts = 3;
+ stopSigValue = -1;
+ //
+ // /* possibly other early stopping criteria here */
+ //
+ numOptThreads = 1;
+ saveInterFiles = 3;
+ compressFiles = 0;
+ oneModificationPerIteration = false;
+ randInit = false;
+ seed = System.currentTimeMillis();
+ // useDisk = 2;
+ // Decoder specs
+ decoderCommandFileName = null;
+ passIterationToDecoder = false;
+ decoderOutFileName = "output.nbest";
+ validDecoderExitValue = 0;
+ decoderConfigFileName = "dec_cfg.txt";
+ sizeOfNBest = 100;
+ fakeFileNameTemplate = null;
+ fakeFileNamePrefix = null;
+ fakeFileNameSuffix = null;
+ // Output specs
+ verbosity = 1;
+ decVerbosity = 0;
+
+ int i = 0;
+
+ while (i < args.length) {
+ String option = args[i];
+ // Relevant files
+ if (option.equals("-dir")) {
+ dirPrefix = args[i + 1];
+ } else if (option.equals("-s")) {
+ sourceFileName = args[i + 1];
+ } else if (option.equals("-r")) {
+ refFileName = args[i + 1];
+ } else if (option.equals("-rps")) {
+ refsPerSen = Integer.parseInt(args[i + 1]);
+ if (refsPerSen < 1) {
+ throw new RuntimeException("refsPerSen must be positive.");
+ }
+ } else if (option.equals("-txtNrm")) {
+ textNormMethod = Integer.parseInt(args[i + 1]);
+ if (textNormMethod < 0 || textNormMethod > 4) {
+ throw new RuntimeException("textNormMethod should be between 0 and 4");
+ }
+ } else if (option.equals("-p")) {
+ paramsFileName = args[i + 1];
+ } else if (option.equals("-docInfo")) {
+ docInfoFileName = args[i + 1];
+ } else if (option.equals("-fin")) {
+ finalLambdaFileName = args[i + 1];
+ // MERT specs
+ } else if (option.equals("-m")) {
+ metricName = args[i + 1];
+ metricName_display = metricName;
+ if (EvaluationMetric.knownMetricName(metricName)) {
+ int optionCount = EvaluationMetric.metricOptionCount(metricName);
+ metricOptions = new String[optionCount];
+ for (int opt = 0; opt < optionCount; ++opt) {
+ metricOptions[opt] = args[i + opt + 2];
+ }
+ i += optionCount;
+ } else {
+ throw new RuntimeException("Unknown metric name " + metricName + ".");
+ }
+ } else if (option.equals("-docSet")) {
+ String method = args[i + 1];
+
+ if (method.equals("all")) {
+ docSubsetInfo[0] = 0;
+ i += 0;
+ } else if (method.equals("bottom")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 1;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 2;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("top")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 3;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 4;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("window")) {
+ String a1 = args[i + 2];
+ a1 = a1.substring(0, a1.indexOf("d")); // size of window
+ String a2 = args[i + 4];
+ if (a2.indexOf("p") > 0) {
+ docSubsetInfo[0] = 5;
+ a2 = a2.substring(0, a2.indexOf("p"));
+ } else {
+ docSubsetInfo[0] = 6;
+ a2 = a2.substring(0, a2.indexOf("r"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a1);
+ docSubsetInfo[6] = Integer.parseInt(a2);
+ i += 3;
+ } else {
+ throw new RuntimeException("Unknown docSet method " + method + ".");
+ }
+ } else if (option.equals("-maxIt")) {
+ maxMERTIterations = Integer.parseInt(args[i + 1]);
+ if (maxMERTIterations < 1) {
+ throw new RuntimeException("maxMERTIts must be positive.");
+ }
+ } else if (option.equals("-minIt")) {
+ minMERTIterations = Integer.parseInt(args[i + 1]);
+ if (minMERTIterations < 1) {
+ throw new RuntimeException("minMERTIts must be positive.");
+ }
+ } else if (option.equals("-prevIt")) {
+ prevMERTIterations = Integer.parseInt(args[i + 1]);
+ if (prevMERTIterations < 0) {
+ throw new RuntimeException("prevMERTIts must be non-negative.");
+ }
+ } else if (option.equals("-stopIt")) {
+ stopMinIts = Integer.parseInt(args[i + 1]);
+ if (stopMinIts < 1) {
+ throw new RuntimeException("stopMinIts must be positive.");
+ }
+ } else if (option.equals("-stopSig")) {
+ stopSigValue = Double.parseDouble(args[i + 1]);
+ }
+ //
+ // /* possibly other early stopping criteria here */
+ //
+ else if (option.equals("-thrCnt")) {
+ numOptThreads = Integer.parseInt(args[i + 1]);
+ if (numOptThreads < 1) {
+
<TRUNCATED>
[13/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
index 2424a1e,0000000..e8242f6
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
@@@ -1,126 -1,0 +1,126 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.tm;
+
+import static org.apache.joshua.decoder.ff.tm.packed.PackedGrammar.VOCABULARY_FILENAME;
+import static org.apache.joshua.util.FormatUtils.cleanNonTerminal;
+import static org.apache.joshua.util.FormatUtils.isNonterminal;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.util.io.LineReader;
-
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class CreateGlueGrammar {
+
+
+ private static final Logger LOG = LoggerFactory.getLogger(CreateGlueGrammar.class);
+
+ private final Set<String> nonTerminalSymbols = new HashSet<>();
+
+ @Option(name = "--grammar", aliases = {"-g"}, required = true, usage = "provide grammar to determine list of NonTerminal symbols.")
+ private String grammarPath;
-
++
+ @Option(name = "--goal", aliases = {"-goal"}, required = false, usage = "specify custom GOAL symbol. Default: 'GOAL'")
+ private final String goalSymbol = cleanNonTerminal(new JoshuaConfiguration().goal_symbol);
+
+ /* Rule templates */
+ // [GOAL] ||| <s> ||| <s> ||| 0
+ private static final String R_START = "[%1$s] ||| <s> ||| <s> ||| 0";
+ // [GOAL] ||| [GOAL,1] [X,2] ||| [GOAL,1] [X,2] ||| -1
+ private static final String R_TWO = "[%1$s] ||| [%1$s,1] [%2$s,2] ||| [%1$s,1] [%2$s,2] ||| -1";
+ // [GOAL] ||| [GOAL,1] </s> ||| [GOAL,1] </s> ||| 0
+ private static final String R_END = "[%1$s] ||| [%1$s,1] </s> ||| [%1$s,1] </s> ||| 0";
+ // [GOAL] ||| <s> [X,1] </s> ||| <s> [X,1] </s> ||| 0
+ private static final String R_TOP = "[%1$s] ||| <s> [%2$s,1] </s> ||| <s> [%2$s,1] </s> ||| 0";
-
++
+ private void run() throws IOException {
-
++
+ File grammar_file = new File(grammarPath);
+ if (!grammar_file.exists()) {
+ throw new IOException("Grammar file doesn't exist: " + grammarPath);
+ }
+
+ // in case of a packedGrammar, we read the serialized vocabulary,
+ // collecting all cleaned nonTerminal symbols.
+ if (grammar_file.isDirectory()) {
+ Vocabulary.read(new File(grammarPath + File.separator + VOCABULARY_FILENAME));
+ for (int i = 0; i < Vocabulary.size(); ++i) {
+ final String token = Vocabulary.word(i);
+ if (isNonterminal(token)) {
+ nonTerminalSymbols.add(cleanNonTerminal(token));
+ }
+ }
+ // otherwise we collect cleaned left-hand sides from the rules in the text grammar.
- } else {
- final LineReader reader = new LineReader(grammarPath);
- while (reader.hasNext()) {
- final String line = reader.next();
- int lhsStart = line.indexOf("[") + 1;
- int lhsEnd = line.indexOf("]");
- if (lhsStart < 1 || lhsEnd < 0) {
- LOG.info("malformed rule: {}\n", line);
- continue;
++ } else {
++ try (final LineReader reader = new LineReader(grammarPath);) {
++ while (reader.hasNext()) {
++ final String line = reader.next();
++ int lhsStart = line.indexOf("[") + 1;
++ int lhsEnd = line.indexOf("]");
++ if (lhsStart < 1 || lhsEnd < 0) {
++ LOG.info("malformed rule: {}\n", line);
++ continue;
++ }
++ final String lhs = line.substring(lhsStart, lhsEnd);
++ nonTerminalSymbols.add(lhs);
+ }
- final String lhs = line.substring(lhsStart, lhsEnd);
- nonTerminalSymbols.add(lhs);
+ }
+ }
-
++
+ LOG.info("{} nonTerminal symbols read: {}", nonTerminalSymbols.size(),
+ nonTerminalSymbols.toString());
+
+ // write glue rules to stdout
-
++
+ System.out.println(String.format(R_START, goalSymbol));
-
++
+ for (String nt : nonTerminalSymbols)
+ System.out.println(String.format(R_TWO, goalSymbol, nt));
-
++
+ System.out.println(String.format(R_END, goalSymbol));
-
++
+ for (String nt : nonTerminalSymbols)
+ System.out.println(String.format(R_TOP, goalSymbol, nt));
+
+ }
-
++
+ public static void main(String[] args) throws IOException {
+ final CreateGlueGrammar glueCreator = new CreateGlueGrammar();
+ final CmdLineParser parser = new CmdLineParser(glueCreator);
+
+ try {
+ parser.parseArgument(args);
+ glueCreator.run();
+ } catch (CmdLineException e) {
+ LOG.error(e.getMessage(), e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
index 2eb7e6f,0000000..bacd294
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/tm/packed/PackedGrammar.java
@@@ -1,997 -1,0 +1,996 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.tm.packed;
+
+/***
+ * This package implements Joshua's packed grammar structure, which enables the efficient loading
+ * and accessing of grammars. It is described in the paper:
- *
++ *
+ * @article{ganitkevitch2012joshua,
+ * Author = {Ganitkevitch, J. and Cao, Y. and Weese, J. and Post, M. and Callison-Burch, C.},
+ * Journal = {Proceedings of WMT12},
+ * Title = {Joshua 4.0: Packing, PRO, and paraphrases},
+ * Year = {2012}}
- *
++ *
+ * The packed grammar works by compiling out the grammar tries into a compact format that is loaded
+ * and parsed directly from Java arrays. A fundamental problem is that Java arrays are indexed
+ * by ints and not longs, meaning the maximum size of the packed grammar is about 2 GB. This forces
+ * the use of packed grammar slices, which together constitute the grammar. The figure in the
- * paper above shows what each slice looks like.
- *
++ * paper above shows what each slice looks like.
++ *
+ * The division across slices is done in a depth-first manner. Consider the entire grammar organized
+ * into a single source-side trie. The splits across slices are done by grouping the root-level
- * outgoing trie arcs --- and the entire trie beneath them --- across slices.
- *
- * This presents a problem: if the subtree rooted beneath a single top-level arc is too big for a
++ * outgoing trie arcs --- and the entire trie beneath them --- across slices.
++ *
++ * This presents a problem: if the subtree rooted beneath a single top-level arc is too big for a
+ * slice, the grammar can't be packed. This happens with very large Hiero grammars, for example,
+ * where there are a *lot* of rules that start with [X].
- *
++ *
+ * One proposed solution was to split that symbol and pack the pieces into separate grammars with a
+ * shared vocabulary, and then rely on Joshua's ability to query multiple grammars for rules to
+ * solve this problem. That approach is not implemented; the update below describes how the
+ * problem is handled instead.
+ *
+ * *UPDATE 10/2015*
+ * The introduction of a SliceAggregatingTrie together with sorting the grammar by the full source string
+ * (not just by the first source word) allows distributing rules with the same first source word
+ * across multiple slices.
+ * @author fhieber
+ */
+
+import static java.util.Collections.sort;
+import static org.apache.joshua.decoder.ff.FeatureMap.getFeature;
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.BufferUnderflowException;
+import java.nio.ByteBuffer;
+import java.nio.IntBuffer;
- import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.FileChannel.MapMode;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.security.DigestInputStream;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
+import java.util.Arrays;
- import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.tm.AbstractGrammar;
+import org.apache.joshua.decoder.ff.tm.BasicRuleCollection;
+import org.apache.joshua.decoder.ff.tm.OwnerId;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleCollection;
+import org.apache.joshua.decoder.ff.tm.Trie;
+import org.apache.joshua.decoder.ff.tm.hash_based.ExtensionIterator;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.encoding.EncoderConfiguration;
+import org.apache.joshua.util.encoding.FloatEncoder;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+
+/**
+ * A grammar loaded from a directory written by the grammar packer. The directory holds a
+ * vocabulary file, an optional "config" file, an "encoding" file describing how feature
+ * values are stored, and one or more "slice_*" file groups. Each group holds a packed
+ * source trie, target trie, features and (optionally) alignments. Rules are materialized
+ * lazily as PackedRule objects when a trie node's rules are requested.
+ */
+public class PackedGrammar extends AbstractGrammar {
+
+ private static final Logger LOG = LoggerFactory.getLogger(PackedGrammar.class);
+ public static final String VOCABULARY_FILENAME = "vocabulary";
+
+ private EncoderConfiguration encoding;
+ private PackedRoot root;
+ private ArrayList<PackedSlice> slices;
+
+ private final File vocabFile; // store path to vocabulary file
+
+ // The grammar packer version this loader accepts; readConfig() rejects any other version.
+ public static final int SUPPORTED_VERSION = 3;
+
+ // A rule cache for commonly used tries to avoid excess object allocations
+ // Testing shows there's up to ~95% hit rate when cache size is 5000 Trie nodes.
+ private final Cache<Trie, List<Rule>> cached_rules;
+
+ // Directory the grammar was loaded from (reported in version-mismatch errors).
+ private final String grammarDir;
+
+ /**
+ * Loads a packed grammar from disk.
+ *
+ * @param grammar_dir directory produced by the grammar packer
+ * @param span_limit maximum path length for rule application, or -1 for no limit
+ * @param owner owner name passed on to AbstractGrammar
+ * @param type grammar type (not used by this constructor)
+ * @param joshuaConfiguration decoder configuration; supplies the rule-cache size
+ * @throws IOException if the grammar directory or one of its files cannot be read
+ */
+ public PackedGrammar(String grammar_dir, int span_limit, String owner, String type,
+ JoshuaConfiguration joshuaConfiguration) throws IOException {
+ super(owner, joshuaConfiguration, span_limit);
+
+ this.grammarDir = grammar_dir;
+
+ // Read the vocabulary.
+ vocabFile = new File(grammar_dir + File.separator + VOCABULARY_FILENAME);
+ LOG.info("Reading vocabulary: {}", vocabFile);
+ if (!Vocabulary.read(vocabFile)) {
+ throw new RuntimeException("mismatches or collisions while reading on-disk vocabulary");
+ }
-
++
+ // Read the config
+ String configFile = grammar_dir + File.separator + "config";
+ if (new File(configFile).exists()) {
+ LOG.info("Reading packed config: {}", configFile);
+ readConfig(configFile);
+ }
-
++
+ // Read the quantizer setup.
+ LOG.info("Reading encoder configuration: {}{}encoding", grammar_dir, File.separator);
+ encoding = new EncoderConfiguration();
+ encoding.load(grammar_dir + File.separator + "encoding");
+
+ // File.list() returns null (instead of throwing) when the directory cannot be listed.
+ final String[] entries = new File(grammar_dir).list();
+ if (entries == null) {
+ throw new IOException("Cannot list packed grammar directory: " + grammar_dir);
+ }
+ final List<String> listing = Arrays.asList(entries);
+ sort(listing); // File.list() has arbitrary sort order
+ slices = new ArrayList<>();
+ for (String fileName : listing) {
+ if (fileName.startsWith("slice_") && fileName.endsWith(".source")) {
+ // Strip the ".source" suffix to obtain the slice's file prefix. Unlike a fixed
+ // substring length, this also works when slice numbers have more digits.
+ final String slicePrefix = fileName.substring(0, fileName.length() - ".source".length());
+ slices.add(new PackedSlice(grammar_dir + File.separator + slicePrefix));
+ }
+ }
+
+ long count = 0;
+ for (PackedSlice s : slices)
+ count += s.estimated.length;
+ root = new PackedRoot(slices);
+ cached_rules = CacheBuilder.newBuilder().maximumSize(joshuaConfiguration.cachedRuleSize).build();
+
+ LOG.info("Loaded {} rules", count);
+ }
+
+ @Override
+ public Trie getTrieRoot() {
+ return root;
+ }
+
+ /**
+ * Rules may be applied to any span when no span limit is set (-1); otherwise only when
+ * {@code pathLength} does not exceed the limit. The start/end indices are not consulted.
+ */
+ @Override
+ public boolean hasRuleForSpan(int startIndex, int endIndex, int pathLength) {
+ return (spanLimit == -1 || pathLength <= spanLimit);
+ }
+
+ /**
+ * Sums {@code featureSize} over all slices.
+ * NOTE(review): featureSize is the second header int of each ".features" file; confirm
+ * that it is a rule count and not, e.g., a feature count.
+ */
+ @Override
+ public int getNumRules() {
+ int num_rules = 0;
+ for (PackedSlice ps : slices)
+ num_rules += ps.featureSize;
+ return num_rules;
+ }
+
+ /**
+ * Computes the MD5 checksum of the vocabulary file.
+ * Can be used for comparing vocabularies across multiple packedGrammars.
+ *
+ * @return the checksum as a lower-case hex string
+ * @throws RuntimeException if MD5 is unavailable or the vocabulary file cannot be read;
+ * the underlying exception is attached as the cause
+ */
+ public String computeVocabularyChecksum() {
+ final MessageDigest md;
+ try {
+ md = MessageDigest.getInstance("MD5");
+ } catch (NoSuchAlgorithmException e) {
+ throw new RuntimeException("Unknown checksum algorithm", e);
+ }
+ final byte[] buffer = new byte[1024];
+ try (final InputStream is = Files.newInputStream(Paths.get(vocabFile.toString()));
+ DigestInputStream dis = new DigestInputStream(is, md)) {
+ while (dis.read(buffer) != -1) {
+ // reading through the DigestInputStream updates md; the bytes themselves are unused
+ }
+ } catch (IOException e) {
+ throw new RuntimeException("Can not find vocabulary file. This should not happen.", e);
+ }
+ // convert each digest byte to two hex digits (adding 0x100 keeps the leading zero)
+ final StringBuilder sb = new StringBuilder();
+ for (byte aDigest : md.digest()) {
+ sb.append(Integer.toString((aDigest & 0xff) + 0x100, 16).substring(1));
+ }
+ return sb.toString();
+ }
+
+ /**
+ * PackedRoot represents the root of the packed grammar trie.
+ * Tries for different source-side firstwords are organized in
+ * packedSlices on disk. A packedSlice can contain multiple trie
+ * roots (i.e. multiple source-side firstwords).
+ * The PackedRoot builds a lookup table, mapping from
+ * source-side firstwords to the addresses in the packedSlices
+ * that represent the subtrie for a particular firstword.
+ * If the GrammarPacker has to distribute rules for a
+ * source-side firstword across multiple slices, a
- * SliceAggregatingTrie node is created that aggregates those
++ * SliceAggregatingTrie node is created that aggregates those
+ * tries to hide this additional complexity from the grammar interface.
+ * This feature allows packing of grammars where the list of rules
+ * for a single source-side firstword would exceed the maximum array
+ * size of Java (2gb).
+ */
- public final class PackedRoot implements Trie {
++ public static final class PackedRoot implements Trie {
+
+ // first-word id -> trie node (a slice's node, or a SliceAggregatingTrie over several)
+ private final HashMap<Integer, Trie> lookup;
+
+ public PackedRoot(final List<PackedSlice> slices) {
+ final Map<Integer, List<Trie>> childTries = collectChildTries(slices);
+ lookup = buildLookupTable(childTries);
+ }
-
++
+ /**
- * Determines whether trie nodes for source first-words are spread over
++ * Determines whether trie nodes for source first-words are spread over
+ * multiple packedSlices by counting their occurrences.
+ * @param slices all slices of the grammar, in load order
+ * @return A mapping from first word ids to a list of trie nodes.
+ */
+ private Map<Integer, List<Trie>> collectChildTries(final List<PackedSlice> slices) {
+ final Map<Integer, List<Trie>> childTries = new HashMap<>();
+ for (PackedSlice packedSlice : slices) {
-
++
+ // number of tries stored in this packedSlice
+ final int num_children = packedSlice.source[0];
+ for (int i = 0; i < num_children; i++) {
+ final int id = packedSlice.source[2 * i + 1];
-
++
+ /* aggregate tries with same root id
+ * obtain a Trie node, already at the correct address in the packedSlice.
+ * In other words, the lookup index already points to the correct trie node in the packedSlice.
+ * packedRoot.match() thus can directly return the result of lookup.get(id);
+ */
+ if (!childTries.containsKey(id)) {
+ childTries.put(id, new ArrayList<>(1));
+ }
+ final Trie trie = packedSlice.root().match(id);
+ childTries.get(id).add(trie);
+ }
+ }
+ return childTries;
+ }
-
++
+ /**
+ * Build a lookup table for children tries.
+ * If the list contains only a single child node, a regular trie node
+ * is inserted into the table; otherwise a SliceAggregatingTrie node is
+ * created that hides this partitioning into multiple packedSlices
+ * upstream.
+ */
+ private HashMap<Integer,Trie> buildLookupTable(final Map<Integer, List<Trie>> childTries) {
+ HashMap<Integer,Trie> lookup = new HashMap<>(childTries.size());
+ for (int id : childTries.keySet()) {
+ final List<Trie> tries = childTries.get(id);
+ if (tries.size() == 1) {
+ lookup.put(id, tries.get(0));
+ } else {
+ lookup.put(id, new SliceAggregatingTrie(tries));
+ }
+ }
+ return lookup;
+ }
+
+ /** Returns the subtrie for the given first word, or null if no rule starts with it. */
+ @Override
+ public Trie match(int word_id) {
+ return lookup.get(word_id);
+ }
+
+ @Override
+ public boolean hasExtensions() {
+ return !lookup.isEmpty();
+ }
+
+ /** Returns the internal lookup table itself (not a copy). */
+ @Override
+ public HashMap<Integer, ? extends Trie> getChildren() {
+ return lookup;
+ }
+
+ @Override
+ public ArrayList<? extends Trie> getExtensions() {
+ return new ArrayList<>(lookup.values());
+ }
+
+ /** The root never carries rules itself. */
+ @Override
+ public boolean hasRules() {
+ return false;
+ }
+
+ @Override
+ public RuleCollection getRuleCollection() {
+ return new BasicRuleCollection(0, new int[0]);
+ }
+
+ @Override
+ public Iterator<Integer> getTerminalExtensionIterator() {
+ return new ExtensionIterator(lookup, true);
+ }
+
+ @Override
+ public Iterator<Integer> getNonterminalExtensionIterator() {
+ return new ExtensionIterator(lookup, false);
+ }
+ }
+
+ /**
+ * One slice of the packed grammar, loaded from the files that share a "slice_*" prefix.
+ * The source trie is read fully into an int array. The target trie, features and the
+ * optional alignments are memory-mapped.
+ */
+ public final class PackedSlice {
+ private final String name;
+
+ // Source trie (child arcs and rule records); read fully into memory and
+ // re-ordered in place by PackedTrie.sortRules().
+ private final int[] source;
+ // Memory-mapped target trie; a node at pointer p stores its parent pointer at p and its token at p + 1.
+ private final IntBuffer target;
+ // Memory-mapped features: an 8-byte header followed by per-block offsets (see getIntFromByteBuffer()).
+ private final ByteBuffer features;
+ // Memory-mapped alignments, or null when the slice has no ".alignments" file.
+ private final ByteBuffer alignments;
+
+ // Target-trie level boundaries, used by getTargetArray() to size its result.
+ private final int[] targetLookup;
+ // Second header int of the features file (see initializeFeatureStructures()).
+ private int featureSize;
+ // Estimated cost per feature block; -infinity until PackedTrie.sortRules() fills it in.
+ private float[] estimated;
+
+ private final static int BUFFER_HEADER_POSITION = 8;
+
+ /**
+ * Provides a cache of packedTrie nodes to be used in getTrie.
+ */
+ private HashMap<Integer, PackedTrie> tries;
+
+ /**
+ * Loads the slice files "prefix.source", "prefix.target", "prefix.target.lookup",
+ * "prefix.features" and, if present, "prefix.alignments".
+ *
+ * @param prefix path prefix shared by the slice's files
+ * @throws IOException if a required file cannot be read or mapped
+ */
+ public PackedSlice(String prefix) throws IOException {
+ name = prefix;
+
+ File source_file = new File(prefix + ".source");
+ File target_file = new File(prefix + ".target");
+ File target_lookup_file = new File(prefix + ".target.lookup");
+ File feature_file = new File(prefix + ".features");
+ File alignment_file = new File(prefix + ".alignments");
+
+ source = fullyLoadFileToArray(source_file);
+ // First int specifies the size of this file, load from 1st int on
+ targetLookup = fullyLoadFileToArray(target_lookup_file, 1);
+
+ target = associateMemoryMappedFile(target_file).asIntBuffer();
+ features = associateMemoryMappedFile(feature_file);
+ initializeFeatureStructures();
+
+ if (alignment_file.exists()) {
+ alignments = associateMemoryMappedFile(alignment_file);
+ } else {
+ alignments = null;
+ }
+
+ tries = new HashMap<>();
+ }
+
+ /**
+ * Helper function to help create all the structures which describe features
+ * in the Slice. Only called during object construction. Reads the block count
+ * (int at byte 0) and featureSize (int at byte 4) from the features header.
+ */
+ private void initializeFeatureStructures() {
+ int num_blocks = features.getInt(0);
+ estimated = new float[num_blocks];
+ Arrays.fill(estimated, Float.NEGATIVE_INFINITY);
+ featureSize = features.getInt(4);
+ }
+
+ /** Reads entry {@code position} of the int table that follows the buffer's 8-byte header. */
+ private int getIntFromByteBuffer(int position, ByteBuffer buffer) {
+ return buffer.getInt(BUFFER_HEADER_POSITION + (4 * position));
+ }
+
+ private int[] fullyLoadFileToArray(File file) throws IOException {
+ return fullyLoadFileToArray(file, 0);
+ }
+
+ /**
+ * This function will use a bulk loading method to fully populate a target
+ * array from file.
+ *
+ * @param file
+ * File that will be read from disk.
+ * @param startIndex
+ * number of leading ints to skip.
+ * @return an int array holding the file's ints from index {@code startIndex} onwards
+ * (i.e. length(file) / 4 - startIndex entries).
+ * @throws IOException if the file cannot be opened or mapped
+ */
+ private int[] fullyLoadFileToArray(File file, int startIndex) throws IOException {
+ IntBuffer buffer = associateMemoryMappedFile(file).asIntBuffer();
+ int size = (int) (file.length() - (4 * startIndex))/4;
+ int[] result = new int[size];
+ buffer.position(startIndex);
+ buffer.get(result, 0, size);
+ return result;
+ }
+
+ /**
+ * Memory-maps the whole file read-only. The mapping stays valid after the channel
+ * is closed by try-with-resources. The size is cast to int, so files are limited to
+ * Integer.MAX_VALUE bytes.
+ */
+ private ByteBuffer associateMemoryMappedFile(File file) throws IOException {
+ try(FileInputStream fileInputStream = new FileInputStream(file)) {
+ FileChannel fileChannel = fileInputStream.getChannel();
+ int size = (int) fileChannel.size();
+ return fileChannel.map(MapMode.READ_ONLY, 0, size);
+ }
+ }
+
+ /**
+ * Reconstructs the target-side token sequence of the target-trie node at {@code pointer}
+ * by following parent links. The result length comes from the level table targetLookup.
+ * NOTE(review): the loop bound admits tgt_length == targetLookup.length, which would index
+ * past the end; presumably the packer's lookup table makes the loop exit earlier — confirm.
+ */
+ private final int[] getTargetArray(int pointer) {
+ // Figure out level.
+ int tgt_length = 1;
+ while (tgt_length < (targetLookup.length + 1) && targetLookup[tgt_length] <= pointer)
+ tgt_length++;
+ int[] tgt = new int[tgt_length];
+ int index = 0;
+ int parent;
+ do {
+ parent = target.get(pointer);
+ if (parent != -1)
+ tgt[index++] = target.get(pointer + 1);
+ pointer = parent;
+ } while (pointer != -1);
+ return tgt;
+ }
+
+ /** Returns the cached trie node at {@code node_address}, creating it on first use. */
+ private synchronized PackedTrie getTrie(final int node_address) {
+ PackedTrie t = tries.get(node_address);
+ if (t == null) {
+ t = new PackedTrie(node_address);
+ tries.put(node_address, t);
+ }
+ return t;
+ }
+
+ /**
+ * Returns the cached trie node at {@code node_address}, creating it on first use as the
+ * child of a node with source side {@code parent_src} reached via {@code symbol}.
+ */
+ private synchronized PackedTrie getTrie(int node_address, int[] parent_src, int parent_arity,
+ int symbol) {
+ PackedTrie t = tries.get(node_address);
+ if (t == null) {
+ t = new PackedTrie(node_address, parent_src, parent_arity, symbol);
+ tries.put(node_address, t);
+ }
+ return t;
+ }
+
+ /**
+ * Returns the FeatureVector associated with a rule (represented as a block ID).
+ * The feature ids are hashed corresponding to feature names prepended with the owner string:
+ * i.e. '0' becomes '<owner>_0'.
+ * @param block_id feature block of the rule
+ * @param ownerId owner used when hashing the feature names
+ * @return feature vector
+ */
+ private final FeatureVector loadFeatureVector(int block_id, OwnerId ownerId) {
+ int featurePosition = getIntFromByteBuffer(block_id, features);
+ final int numFeatures = encoding.readId(features, featurePosition);
+
+ featurePosition += EncoderConfiguration.ID_SIZE;
+ final FeatureVector featureVector = new FeatureVector(encoding.getNumDenseFeatures());
+ FloatEncoder encoder;
+
+ for (int i = 0; i < numFeatures; i++) {
+ final int innerId = encoding.readId(features, featurePosition);
+ encoder = encoding.encoder(innerId);
+ final int outerId = encoding.outerId(innerId);
+ final int ownedFeatureId = hashFeature(getFeature(outerId), ownerId);
+ final float value = encoder.read(features, featurePosition);
-
++
+ featureVector.add(ownedFeatureId, value);
+ featurePosition += EncoderConfiguration.ID_SIZE + encoder.size();
+ }
-
++
+ return featureVector;
+ }
+
+ /**
+ * We need to synchronize this method as there is a many to one ratio between
+ * PackedRule/PhrasePair and this class (PackedSlice). This means during concurrent first
+ * getAlignments calls to PackedRule objects they could alter each other's positions within the
+ * buffer before calling read on the buffer.
+ *
+ * @throws RuntimeException if this slice has no alignments
+ */
+ private synchronized byte[] getAlignmentArray(int block_id) {
+ if (alignments == null)
+ throw new RuntimeException("No alignments available.");
+ int alignment_position = getIntFromByteBuffer(block_id, alignments);
- int num_points = (int) alignments.get(alignment_position);
++ int num_points = alignments.get(alignment_position);
+ byte[] alignment = new byte[num_points * 2];
+
+ alignments.position(alignment_position + 1);
+ try {
+ alignments.get(alignment, 0, num_points * 2);
+ } catch (BufferUnderflowException bue) {
+ LOG.warn("Had an exception when accessing alignment mapped byte buffer");
+ LOG.warn("Attempting to access alignments at position: {}", alignment_position + 1);
+ LOG.warn("And to read this many bytes: {}", num_points * 2);
+ LOG.warn("Buffer capacity is : {}", alignments.capacity());
+ LOG.warn("Buffer position is : {}", alignments.position());
+ LOG.warn("Buffer limit is : {}", alignments.limit());
+ throw bue;
+ }
+ return alignment;
+ }
+
+ /** Returns the trie node at address 0 of this slice's source array. */
+ private PackedTrie root() {
+ return getTrie(0);
+ }
+
++ @Override
+ public String toString() {
+ return name;
+ }
+
+ /**
+ * A trie node within the grammar slice. Identified by its position within the source array,
+ * and, as a supplement, the source string leading from the trie root to the node.
- *
++ *
+ * @author jg
- *
++ *
+ */
+ public class PackedTrie implements Trie, RuleCollection {
+
+ private final int position;
+
+ // Read without the lock in isSorted(); volatile so a completed sortRules() is visible.
+ private volatile boolean sorted = false;
+
+ private final int[] src;
+ private int arity;
+
+ private PackedTrie(int position) {
+ this.position = position;
+ src = new int[0];
+ arity = 0;
+ }
+
+ private PackedTrie(int position, int[] parent_src, int parent_arity, int symbol) {
+ this.position = position;
+ src = new int[parent_src.length + 1];
+ System.arraycopy(parent_src, 0, src, 0, parent_src.length);
+ src[src.length - 1] = symbol;
+ arity = parent_arity;
+ if (FormatUtils.isNonterminal(symbol))
+ arity++;
+ }
+
+ /**
+ * Binary-searches this node's child arcs for {@code token_id}. Child symbols are
+ * stored in descending order. Returns null if there is no such child.
+ */
+ @Override
+ public final Trie match(int token_id) {
+ int num_children = source[position];
+ if (num_children == 0)
+ return null;
+ if (num_children == 1 && token_id == source[position + 1])
+ return getTrie(source[position + 2], src, arity, token_id);
+ int top = 0;
+ int bottom = num_children - 1;
+ while (true) {
+ int candidate = (top + bottom) / 2;
+ int candidate_position = position + 1 + 2 * candidate;
+ int read_token = source[candidate_position];
+ if (read_token == token_id) {
+ return getTrie(source[candidate_position + 1], src, arity, token_id);
+ } else if (top == bottom) {
+ return null;
+ } else if (read_token > token_id) {
+ top = candidate + 1;
+ } else {
+ bottom = candidate - 1;
+ }
+ if (bottom < top)
+ return null;
+ }
+ }
+
+ @Override
+ public HashMap<Integer, ? extends Trie> getChildren() {
+ HashMap<Integer, Trie> children = new HashMap<>();
+ int num_children = source[position];
+ for (int i = 0; i < num_children; i++) {
+ int symbol = source[position + 1 + 2 * i];
+ int address = source[position + 2 + 2 * i];
+ children.put(symbol, getTrie(address, src, arity, symbol));
+ }
+ return children;
+ }
+
+ @Override
+ public boolean hasExtensions() {
+ return (source[position] != 0);
+ }
+
+ @Override
+ public ArrayList<? extends Trie> getExtensions() {
+ int num_children = source[position];
+ ArrayList<PackedTrie> tries = new ArrayList<>(num_children);
+
+ for (int i = 0; i < num_children; i++) {
+ int symbol = source[position + 1 + 2 * i];
+ int address = source[position + 2 + 2 * i];
+ tries.add(getTrie(address, src, arity, symbol));
+ }
+
+ return tries;
+ }
+
+ @Override
+ public boolean hasRules() {
+ int num_children = source[position];
+ return (source[position + 1 + 2 * num_children] != 0);
+ }
+
+ @Override
+ public RuleCollection getRuleCollection() {
+ return this;
+ }
+
+ /**
+ * Returns this node's rules in their current order in {@code source}. The list is cached
+ * in cached_rules until sortRules() invalidates it.
+ */
+ @Override
+ public List<Rule> getRules() {
+ List<Rule> rules = cached_rules.getIfPresent(this);
+ if (rules != null) {
+ return rules;
+ }
+
+ int num_children = source[position];
+ int rule_position = position + 2 * (num_children + 1);
+ int num_rules = source[rule_position - 1];
+
+ rules = new ArrayList<>(num_rules);
+ for (int i = 0; i < num_rules; i++) {
+ rules.add(new PackedRule(rule_position + 3 * i));
+ }
+
+ cached_rules.put(this, rules);
+ return rules;
+ }
+
+ /**
+ * Returns whether sortRules() has completed for this trie node.
+ */
+ @Override
+ public boolean isSorted() {
+ return sorted;
+ }
+
+ /**
+ * Estimates rule costs for all rules at this trie node, records them in
+ * {@code estimated}, and reorders the node's rule records in {@code source} by
+ * descending estimated value.
+ */
+ private synchronized void sortRules(List<FeatureFunction> featureFunctions) {
+ // Another thread may have finished sorting while this one waited for the lock.
+ if (this.sorted)
+ return;
+ int num_children = source[position];
+ int rule_position = position + 2 * (num_children + 1);
+ int num_rules = source[rule_position - 1];
+ if (num_rules == 0) {
+ this.sorted = true;
+ return;
+ }
+ final Integer[] rules = new Integer[num_rules];
+
+ // Locals are named so they do not shadow the slice's target/features/alignments fields.
+ int block_id;
+ int lhs;
+ int[] ruleTarget;
+ byte[] ruleAlignments;
+ FeatureVector ruleFeatures;
-
++
+ for (int i = 0; i < num_rules; ++i) {
+ // we construct very short-lived rule objects for sorting
+ rules[i] = rule_position + 2 + 3 * i;
+ block_id = source[rules[i]];
+ lhs = source[rule_position + 3 * i];
+ ruleTarget = getTargetArray(source[rule_position + 1 + 3 * i]);
+ ruleFeatures = loadFeatureVector(block_id, owner);
+ // Alignments are optional: getAlignmentArray() throws when the slice has none, so
+ // use null then (as PackedRule.getAlignment() does) instead of aborting the sort.
+ ruleAlignments = (PackedSlice.this.alignments == null) ? null : getAlignmentArray(block_id);
+ final Rule rule = new Rule(lhs, src, ruleTarget, arity, ruleFeatures, ruleAlignments, owner);
+ estimated[block_id] = rule.estimateRuleCost(featureFunctions);
+ }
+
+ Arrays.sort(rules, (a, b) -> {
+ float a_cost = estimated[source[a]];
+ float b_cost = estimated[source[b]];
+ if (a_cost == b_cost)
+ return 0;
+ return (a_cost > b_cost ? -1 : 1);
+ });
+
+ int[] sorted = new int[3 * num_rules];
+ int j = 0;
+ for (Integer address : rules) {
+ sorted[j++] = source[address - 2];
+ sorted[j++] = source[address - 1];
+ sorted[j++] = source[address];
+ }
+ System.arraycopy(sorted, 0, source, rule_position + 0, sorted.length);
+
+ // Replace rules in cache with their sorted values on next getRules()
+ cached_rules.invalidate(this);
+ this.sorted = true;
+ }
+
+ @Override
+ public List<Rule> getSortedRules(List<FeatureFunction> featureFunctions) {
+ if (!isSorted())
+ sortRules(featureFunctions);
+ return getRules();
+ }
+
+ @Override
+ public int[] getSourceSide() {
+ return src;
+ }
+
+ @Override
+ public int getArity() {
+ return arity;
+ }
+
+ @Override
+ public Iterator<Integer> getTerminalExtensionIterator() {
+ return new PackedChildIterator(position, true);
+ }
+
+ @Override
+ public Iterator<Integer> getNonterminalExtensionIterator() {
+ return new PackedChildIterator(position, false);
+ }
+
+ /**
+ * Iterates over a node's child symbols. In terminal mode it scans forward from the
+ * first child while symbols are positive. In nonterminal mode it scans backward from
+ * the last child while symbols are negative.
+ */
+ public final class PackedChildIterator implements Iterator<Integer> {
+
+ private int current;
+ private final boolean terminal;
+ private boolean done;
+ private int last;
+
+ PackedChildIterator(int position, boolean terminal) {
+ this.terminal = terminal;
+ int num_children = source[position];
+ done = (num_children == 0);
+ if (!done) {
+ current = (terminal ? position + 1 : position - 1 + 2 * num_children);
+ last = (terminal ? position - 1 + 2 * num_children : position + 1);
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (done)
+ return false;
+ int next = (terminal ? current + 2 : current - 2);
+ if (next == last)
+ return false;
+ return (terminal ? source[next] > 0 : source[next] < 0);
+ }
+
+ @Override
+ public Integer next() {
+ if (done)
+ throw new RuntimeException("No more symbols!");
+ int symbol = source[current];
+ if (current == last)
+ done = true;
+ if (!done) {
+ current = (terminal ? current + 2 : current - 2);
+ done = (terminal ? source[current] < 0 : source[current] > 0);
+ }
+ return symbol;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
-
++
+ /**
+ * A packed phrase pair represents a rule of the form of a phrase pair, packed with the
+ * grammar-packer.pl script, which simply adds a nonterminal [X] to the left-hand side of
+ * all phrase pairs (and converts the Moses features). The packer then packs these. We have
+ * to then put a nonterminal on the source and target sides to treat the phrase pairs like
- * left-branching rules, which is how Joshua deals with phrase decoding.
- *
++ * left-branching rules, which is how Joshua deals with phrase decoding.
++ *
+ * @author Matt Post post@cs.jhu.edu
+ *
+ */
+ public final class PackedPhrasePair extends PackedRule {
+
+ private final Supplier<int[]> targetSupplier;
+ private final Supplier<byte[]> alignmentSupplier;
+
+ public PackedPhrasePair(int address) {
+ super(address);
+ targetSupplier = initializeTargetSupplier();
+ alignmentSupplier = initializeAlignmentSupplier();
+ }
+
+ @Override
+ public int getArity() {
+ return PackedTrie.this.getArity() + 1;
+ }
+
+ /**
+ * Initialize a number of suppliers which get evaluated when their respective getters
+ * are called.
+ * Inner lambda functions are guaranteed to only be called once, because of this underlying
+ * structures are accessed in a threadsafe way.
+ * Guava's implementation makes sure only one read of a volatile variable occurs per get.
+ * This means this implementation should be as thread-safe and performant as possible.
+ */
+ private Supplier<int[]> initializeTargetSupplier(){
+ Supplier<int[]> result = Suppliers.memoize(() ->{
+ int[] phrase = getTargetArray(source[address + 1]);
+ int[] tgt = new int[phrase.length + 1];
+ tgt[0] = -1;
+ System.arraycopy(phrase, 0, tgt, 1, phrase.length);
+ return tgt;
+ });
+ return result;
+ }
+
+ private Supplier<byte[]> initializeAlignmentSupplier(){
+ return Suppliers.memoize(() ->{
+ byte[] raw_alignment = getAlignmentArray(source[address + 2]);
+ byte[] points = new byte[raw_alignment.length + 2];
+ points[0] = points[1] = 0;
+ for (int i = 0; i < raw_alignment.length; i++)
+ points[i + 2] = (byte) (raw_alignment[i] + 1);
+ return points;
+ });
+ }
+
+ /**
+ * Take the target phrase of the underlying rule and prepend an [X].
- *
++ *
+ * @return the augmented phrase
+ */
+ @Override
+ public int[] getTarget() {
+ return this.targetSupplier.get();
+ }
-
++
+ /**
+ * Take the source phrase of the underlying rule and prepend an [X].
- *
++ *
+ * @return the augmented source phrase
+ */
+ @Override
+ public int[] getSource() {
+ int[] phrase = new int[src.length + 1];
+ int ntid = Vocabulary.id(PackedGrammar.this.joshuaConfiguration.default_non_terminal);
+ phrase[0] = ntid;
+ System.arraycopy(src, 0, phrase, 1, src.length);
+ return phrase;
+ }
-
++
+ /**
+ * Similarly the alignment array needs to be shifted over by one.
- *
++ *
+ * @return the byte[] alignment, or null if the grammar has no alignments
+ */
+ @Override
+ public byte[] getAlignment() {
+ // if no alignments in grammar do not fail
+ if (alignments == null) {
+ return null;
+ }
+
+ return this.alignmentSupplier.get();
+ }
+ }
+
+ /**
+ * A rule at this trie node, identified by the address of its three-int record in
+ * {@code source}: left-hand side, target-trie pointer, feature block id. The target,
+ * features and alignments are loaded lazily on first access.
+ */
+ public class PackedRule extends Rule {
+ protected final int address;
+ private final Supplier<int[]> targetSupplier;
+ private final Supplier<FeatureVector> featureVectorSupplier;
+ private final Supplier<byte[]> alignmentsSupplier;
+
+ public PackedRule(int address) {
+ super(source[address], src, null, PackedTrie.this.getArity(), null, null, owner);
+ this.address = address;
+ this.targetSupplier = intializeTargetSupplier();
+ this.featureVectorSupplier = initializeFeatureVectorSupplier();
+ this.alignmentsSupplier = initializeAlignmentsSupplier();
+ }
+
+ private Supplier<int[]> intializeTargetSupplier(){
+ Supplier<int[]> result = Suppliers.memoize(() ->{
+ return getTargetArray(source[address + 1]);
+ });
+ return result;
+ }
+
+ private Supplier<FeatureVector> initializeFeatureVectorSupplier(){
+ Supplier<FeatureVector> result = Suppliers.memoize(() ->{
+ return loadFeatureVector(source[address + 2], owner);
+ });
+ return result;
+ }
+
+ private Supplier<byte[]> initializeAlignmentsSupplier(){
+ return Suppliers.memoize(()->{
+ // if no alignments in grammar do not fail
+ if (alignments == null){
+ return null;
+ }
+ return getAlignmentArray(source[address + 2]);
+ });
+ }
+
+ @Override
+ public int getArity() {
+ return PackedTrie.this.getArity();
+ }
+
+ @Override
+ public int getLHS() {
+ return source[address];
+ }
+
+ @Override
+ public int[] getTarget() {
+ return this.targetSupplier.get();
+ }
+
+ @Override
+ public int[] getSource() {
+ return src;
+ }
+
+ @Override
+ public FeatureVector getFeatureVector() {
+ return this.featureVectorSupplier.get();
+ }
-
++
+ @Override
+ public byte[] getAlignment() {
+ return this.alignmentsSupplier.get();
+ }
+
+ /** Returns the cost estimated by sortRules(), or -infinity if not yet sorted. */
+ @Override
+ public float getEstimatedCost() {
+ return estimated[source[address + 2]];
+ }
+
+ /** Returns the precomputed estimate; {@code models} is not consulted. */
+ @Override
+ public float estimateRuleCost(List<FeatureFunction> models) {
+ return estimated[source[address + 2]];
+ }
+ }
+ }
+ }
+
+ /**
+ * Packed grammars are read-only.
+ *
+ * @throws UnsupportedOperationException always
+ */
+ @Override
+ public void addOOVRules(int word, List<FeatureFunction> featureFunctions) {
+ throw new UnsupportedOperationException("PackedGrammar.addOOVRules(): I can't add OOV rules");
+ }
-
++
+ /**
+ * Packed grammars are read-only.
+ *
+ * @throws UnsupportedOperationException always
+ */
+ @Override
+ public void addRule(Rule rule) {
+ throw new UnsupportedOperationException("PackedGrammar.addRule(): I can't add rules");
+ }
-
- /**
++
++ /**
+ * Read the packed-grammar config file, a list of "key = value" lines. The keys
+ * "max-source-len" (sets maxSourcePhraseLength) and "version" are used; other lines
+ * are ignored.
- *
++ *
+ * TODO: this should be rewritten using typeconfig.
- *
++ *
+ * @param config path of the config file
+ * @throws IOException if the file cannot be read
+ * @throws RuntimeException if the grammar was packed with a version other than SUPPORTED_VERSION
+ */
+ private void readConfig(String config) throws IOException {
+ int version = 0;
-
++
+ // try-with-resources so the reader is closed even if parsing throws
+ try (final LineReader reader = new LineReader(config)) {
+ for (String line : reader) {
+ String[] tokens = line.split(" = ");
+ if (tokens[0].equals("max-source-len"))
+ this.maxSourcePhraseLength = Integer.parseInt(tokens[1]);
+ else if (tokens[0].equals("version")) {
+ version = Integer.parseInt(tokens[1]);
+ }
+ }
+ }
-
++
+ if (version != SUPPORTED_VERSION) {
+ String message = String.format("The grammar at %s was packed with packer version %d, but the earliest supported version is %d",
+ this.grammarDir, version, SUPPORTED_VERSION);
+ throw new RuntimeException(message);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/lattice/Lattice.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/lattice/Lattice.java
index 2332159,0000000..c557c07
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/lattice/Lattice.java
+++ b/joshua-core/src/main/java/org/apache/joshua/lattice/Lattice.java
@@@ -1,587 -1,0 +1,541 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.lattice;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
- import java.util.Stack;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.segment_file.Token;
+import org.apache.joshua.util.ChartSpan;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A lattice representation of a directed graph.
+ *
+ * @author Lane Schwartz
+ * @author Matt Post post@cs.jhu.edu
+ * @since 2008-07-08
+ *
+ */
+public class Lattice<Value> implements Iterable<Node<Value>> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(Lattice.class);
+
+ /**
+ * True if there is more than one path through the lattice.
+ */
+ private boolean latticeHasAmbiguity;
+
+ /**
+ * Costs of the best path between each pair of nodes in the lattice.
+ */
+ private ChartSpan<Integer> distances = null;
+
+ /**
+ * List of all nodes in the lattice. Nodes are assumed to be in topological order.
+ */
+ private List<Node<Value>> nodes;
+
+
+ JoshuaConfiguration config = null;
+
+ /**
+ * Constructs a new lattice from an existing list of (connected) nodes.
+ * <p>
+ * The list of nodes must already be in topological order. If the list is not in topological
+ * order, the behavior of the lattice is not defined.
+ *
+ * @param nodes A list of nodes which must be in topological order.
+ * @param config a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ */
+ public Lattice(List<Node<Value>> nodes, JoshuaConfiguration config) {
+ this.nodes = nodes;
+ // this.distances = calculateAllPairsShortestPath();
+ this.latticeHasAmbiguity = true;
+ }
+
+ public Lattice(List<Node<Value>> nodes, boolean isAmbiguous, JoshuaConfiguration config) {
+ // Node<Value> sink = new Node<Value>(nodes.size());
+ // nodes.add(sink);
+ this.nodes = nodes;
+ // this.distances = calculateAllPairsShortestPath();
+ this.latticeHasAmbiguity = isAmbiguous;
+ }
+
+ /**
+ * Instantiates a lattice from a linear chain of values, i.e., a sentence.
+ *
+ * @param linearChain a sequence of Value objects
+ * @param config a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ */
+ public Lattice(Value[] linearChain, JoshuaConfiguration config) {
+ this.latticeHasAmbiguity = false;
+ this.nodes = new ArrayList<Node<Value>>();
+
+ Node<Value> previous = new Node<Value>(0);
+ nodes.add(previous);
+
+ int i = 1;
+
+ for (Value value : linearChain) {
+
+ Node<Value> current = new Node<Value>(i);
+ float cost = 0.0f;
+ // if (i > 4) cost = (float)i/1.53432f;
+ previous.addArc(current, cost, value);
+
+ nodes.add(current);
+
+ previous = current;
+ i++;
+ }
+
+ // this.distances = calculateAllPairsShortestPath();
+ }
+
+ public final boolean hasMoreThanOnePath() {
+ return latticeHasAmbiguity;
+ }
+
+ /**
+ * Computes the shortest distance between two nodes, which is used (perhaps among other places) in
+ * computing which rules can apply over which spans of the input
+ *
+ * @param arc an {@link org.apache.joshua.lattice.Arc} of values
+ * @return the shortest distance between two nodes
+ */
+ public int distance(Arc<Value> arc) {
+ return this.getShortestPath(arc.getTail().getNumber(), arc.getHead().getNumber());
+ }
+
+ public int distance(int i, int j) {
+ return this.getShortestPath(i, j);
+ }
+
+ /**
+ * Convenience method to get a lattice from a linear sequence of {@link Token} objects.
+ *
+ * @param source input string from which to create a {@link org.apache.joshua.lattice.Lattice}
+ * @param config a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ * @return Lattice representation of the linear chain.
+ */
+ public static Lattice<Token> createTokenLatticeFromString(String source, JoshuaConfiguration config) {
+ String[] tokens = source.split("\\s+");
+ Token[] integerSentence = new Token[tokens.length];
+ for (int i = 0; i < tokens.length; i++) {
+ integerSentence[i] = new Token(tokens[i], config);
+ }
+
+ return new Lattice<Token>(integerSentence, config);
+ }
+
+ public static Lattice<Token> createTokenLatticeFromPLF(String data, JoshuaConfiguration config) {
+ ArrayList<Node<Token>> nodes = new ArrayList<Node<Token>>();
+
+ // This matches a sequence of tuples, which describe arcs leaving this node
+ Pattern nodePattern = Pattern.compile("(.+?)\\(\\s*(\\(.+?\\),\\s*)\\s*\\)(.*)");
+
+ /*
+ * This matches a comma-delimited, parenthesized tuple of a (a) single-quoted word (b) a number,
+ * optionally in scientific notation (c) an offset (how many states to jump ahead)
+ */
+ Pattern arcPattern = Pattern
+ .compile("\\s*\\('(.+?)',\\s*(-?\\d+\\.?\\d*?(?:[eE]-?\\d+)?),\\s*(\\d+)\\),\\s*(.*)");
+
+ Matcher nodeMatcher = nodePattern.matcher(data);
+
+ boolean latticeIsAmbiguous = false;
+
+ int nodeID = 0;
+ Node<Token> startNode = new Node<Token>(nodeID);
+ nodes.add(startNode);
+
+ while (nodeMatcher.matches()) {
+
+ String nodeData = nodeMatcher.group(2);
+ String remainingData = nodeMatcher.group(3);
+
+ nodeID++;
+
+ Node<Token> currentNode = null;
+ if (nodeID < nodes.size() && nodes.get(nodeID) != null) {
+ currentNode = nodes.get(nodeID);
+ } else {
+ currentNode = new Node<Token>(nodeID);
+ while (nodeID > nodes.size())
+ nodes.add(new Node<Token>(nodes.size()));
+ nodes.add(currentNode);
+ }
+
+ Matcher arcMatcher = arcPattern.matcher(nodeData);
+ int numArcs = 0;
+ if (!arcMatcher.matches()) {
+ throw new RuntimeException("Parse error!");
+ }
+ while (arcMatcher.matches()) {
+ numArcs++;
+ String arcLabel = arcMatcher.group(1);
+ float arcWeight = Float.parseFloat(arcMatcher.group(2));
+ int destinationNodeID = nodeID + Integer.parseInt(arcMatcher.group(3));
+
+ Node<Token> destinationNode;
+ if (destinationNodeID < nodes.size() && nodes.get(destinationNodeID) != null) {
+ destinationNode = nodes.get(destinationNodeID);
+ } else {
+ destinationNode = new Node<Token>(destinationNodeID);
+ while (destinationNodeID > nodes.size())
+ nodes.add(new Node<Token>(nodes.size()));
+ nodes.add(destinationNode);
+ }
+
+ String remainingArcs = arcMatcher.group(4);
+
+ Token arcToken = new Token(arcLabel, config);
+ currentNode.addArc(destinationNode, arcWeight, arcToken);
+
+ arcMatcher = arcPattern.matcher(remainingArcs);
+ }
+ if (numArcs > 1)
+ latticeIsAmbiguous = true;
+
+ nodeMatcher = nodePattern.matcher(remainingData);
+ }
+
+ /* Add <s> to the start of the lattice. */
+ if (nodes.size() > 1 && nodes.get(1) != null) {
+ Node<Token> firstNode = nodes.get(1);
+ startNode.addArc(firstNode, 0.0f, new Token(Vocabulary.START_SYM, config));
+ }
+
+ /* Add </s> as a final state, connect it to the previous end-state */
+ nodeID = nodes.get(nodes.size()-1).getNumber() + 1;
+ Node<Token> endNode = new Node<Token>(nodeID);
+ nodes.get(nodes.size()-1).addArc(endNode, 0.0f, new Token(Vocabulary.STOP_SYM, config));
+ nodes.add(endNode);
+
+ return new Lattice<Token>(nodes, latticeIsAmbiguous, config);
+ }
+
+ /**
+ * Constructs a lattice from a given string representation.
+ *
+ * @param data String representation of a lattice.
+ * @param config a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ * @return A lattice that corresponds to the given string.
+ */
+ public static Lattice<String> createStringLatticeFromString(String data, JoshuaConfiguration config) {
+
+ Map<Integer, Node<String>> nodes = new HashMap<Integer, Node<String>>();
+
+ Pattern nodePattern = Pattern.compile("(.+?)\\((\\(.+?\\),)\\)(.*)");
+ Pattern arcPattern = Pattern.compile("\\('(.+?)',(\\d+.\\d+),(\\d+)\\),(.*)");
+
+ Matcher nodeMatcher = nodePattern.matcher(data);
+
+ int nodeID = -1;
+
+ while (nodeMatcher.matches()) {
+
+ String nodeData = nodeMatcher.group(2);
+ String remainingData = nodeMatcher.group(3);
+
+ nodeID++;
+
+ Node<String> currentNode;
+ if (nodes.containsKey(nodeID)) {
+ currentNode = nodes.get(nodeID);
+ } else {
+ currentNode = new Node<String>(nodeID);
+ nodes.put(nodeID, currentNode);
+ }
+
+ LOG.debug("Node : {}", nodeID);
+
+ Matcher arcMatcher = arcPattern.matcher(nodeData);
+
+ while (arcMatcher.matches()) {
+ String arcLabel = arcMatcher.group(1);
+ float arcWeight = Float.valueOf(arcMatcher.group(2));
+ int destinationNodeID = nodeID + Integer.parseInt(arcMatcher.group(3));
+
+ Node<String> destinationNode;
+ if (nodes.containsKey(destinationNodeID)) {
+ destinationNode = nodes.get(destinationNodeID);
+ } else {
+ destinationNode = new Node<String>(destinationNodeID);
+ nodes.put(destinationNodeID, destinationNode);
+ }
+
+ String remainingArcs = arcMatcher.group(4);
+
+ LOG.debug("\t{} {} {}", arcLabel, arcWeight, destinationNodeID);
+
+ currentNode.addArc(destinationNode, arcWeight, arcLabel);
+
+ arcMatcher = arcPattern.matcher(remainingArcs);
+ }
+
+ nodeMatcher = nodePattern.matcher(remainingData);
+ }
+
+ List<Node<String>> nodeList = new ArrayList<Node<String>>(nodes.values());
+ Collections.sort(nodeList, new NodeIdentifierComparator());
+
+ LOG.debug("Nodelist={}", nodeList);
+
+ return new Lattice<String>(nodeList, config);
+ }
+
+ /**
+ * Gets the cost of the shortest path between two nodes.
+ *
+ * @param from ID of the starting node.
+ * @param to ID of the ending node.
+ * @return The cost of the shortest path between the two nodes.
+ */
+ public int getShortestPath(int from, int to) {
+ // System.err.println(String.format("DISTANCE(%d,%d) = %f", from, to, costs[from][to]));
+ if (distances == null)
+ this.distances = calculateAllPairsShortestPath();
+
+ return distances.get(from, to);
+ }
+
+ /**
+ * Gets the shortest distance through the lattice.
+ * @return int representing the shortest distance through the lattice
+ */
+ public int getShortestDistance() {
+ if (distances == null)
+ distances = calculateAllPairsShortestPath();
+ return distances.get(0, nodes.size()-1);
+ }
+
+ /**
+ * Gets the node with a specified integer identifier. If the identifier is negative, we count
+ * backwards from the end of the array, Perl-style (-1 is the last element, -2 the penultimate,
+ * etc).
+ *
+ * @param index Integer identifier for a node.
+ * @return The node with the specified integer identifier
+ */
+ public Node<Value> getNode(int index) {
+ if (index >= 0)
+ return nodes.get(index);
+ else
+ return nodes.get(size() + index);
+ }
+
+ public List<Node<Value>> getNodes() {
+ return nodes;
+ }
+
+ /**
+ * Returns an iterator over the nodes in this lattice.
+ *
+ * @return An iterator over the nodes in this lattice.
+ */
++ @Override
+ public Iterator<Node<Value>> iterator() {
+ return nodes.iterator();
+ }
+
+ /**
+ * Returns the number of nodes in this lattice.
+ *
+ * @return The number of nodes in this lattice.
+ */
+ public int size() {
+ return nodes.size();
+ }
+
+ /**
+ * Calculate the all-pairs shortest path for all pairs of nodes.
+ * <p>
+ * Note: This method assumes no backward arcs. If there are backward arcs, the returned shortest
+ * path costs for that node may not be accurate.
+ *
+ * @param nodes A list of nodes which must be in topological order.
+ * @return The all-pairs shortest path for all pairs of nodes.
+ */
+ private ChartSpan<Integer> calculateAllPairsShortestPath() {
+
+ ChartSpan<Integer> distance = new ChartSpan<Integer>(nodes.size() - 1, Integer.MAX_VALUE);
+ distance.setDiagonal(0);
+
+ /* Mark reachability between immediate neighbors */
+ for (Node<Value> tail : nodes) {
+ for (Arc<Value> arc : tail.getOutgoingArcs()) {
+ Node<Value> head = arc.getHead();
+ distance.set(tail.id(), head.id(), 1);
+ }
+ }
+
+ int size = nodes.size();
+
+ for (int width = 2; width <= size; width++) {
+ for (int i = 0; i < size - width; i++) {
+ int j = i + width;
+ for (int k = i + 1; k < j; k++) {
+ distance.set(i, j, Math.min(distance.get(i, j), distance.get(i, k) + distance.get(k, j)));
+ }
+ }
+ }
+
+ return distance;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder s = new StringBuilder();
+
+ for (Node<Value> start : this) {
+ for (Arc<Value> arc : start.getOutgoingArcs()) {
+ s.append(arc.toString());
+ s.append('\n');
+ }
+ }
+
+ return s.toString();
+ }
+
+ public static void main(String[] args) {
+
+ List<Node<String>> nodes = new ArrayList<Node<String>>();
+ for (int i = 0; i < 4; i++) {
+ nodes.add(new Node<String>(i));
+ }
+
+ nodes.get(0).addArc(nodes.get(1), 1.0f, "x");
+ nodes.get(1).addArc(nodes.get(2), 1.0f, "y");
+ nodes.get(0).addArc(nodes.get(2), 1.5f, "a");
+ nodes.get(2).addArc(nodes.get(3), 3.0f, "b");
+ nodes.get(2).addArc(nodes.get(3), 5.0f, "c");
+
+ Lattice<String> graph = new Lattice<String>(nodes, null);
+
+ System.out.println("Shortest path from 0 to 3: " + graph.getShortestPath(0, 3));
+ }
+
+ /**
+ * Replaced the arc from node i to j with the supplied lattice. This is used to do OOV
+ * segmentation of words in a lattice.
+ *
+ * @param i start node of arc
+ * @param j end node of arc
+ * @param newNodes new nodes used within the replacement operation
+ */
+ public void insert(int i, int j, List<Node<Value>> newNodes) {
+
+ nodes.get(i).setOutgoingArcs(newNodes.get(0).getOutgoingArcs());
+
+ newNodes.remove(0);
+ nodes.remove(j);
+ Collections.reverse(newNodes);
+
+ for (Node<Value> node: newNodes)
+ nodes.add(j, node);
+
+ this.latticeHasAmbiguity = false;
+ for (int x = 0; x < nodes.size(); x++) {
+ nodes.get(x).setID(x);
+ this.latticeHasAmbiguity |= (nodes.get(x).getOutgoingArcs().size() > 1);
+ }
+
+ this.distances = null;
+ }
+
+ /**
- * Topologically sorts the nodes and reassigns their numbers. Assumes that the first node is the
- * source, but otherwise assumes nothing about the input.
- *
- * Probably correct, but untested.
- */
- @SuppressWarnings("unused")
- private void topologicalSort() {
- HashMap<Node<Value>, List<Arc<Value>>> outgraph = new HashMap<Node<Value>, List<Arc<Value>>>();
- HashMap<Node<Value>, List<Arc<Value>>> ingraph = new HashMap<Node<Value>, List<Arc<Value>>>();
- for (Node<Value> node: nodes) {
- ArrayList<Arc<Value>> arcs = new ArrayList<Arc<Value>>();
- for (Arc<Value> arc: node.getOutgoingArcs()) {
- arcs.add(arc);
-
- if (! ingraph.containsKey(arc.getHead()))
- ingraph.put(arc.getHead(), new ArrayList<Arc<Value>>());
- ingraph.get(arc.getHead()).add(arc);
-
- outgraph.put(node, arcs);
- }
- }
-
- ArrayList<Node<Value>> sortedNodes = new ArrayList<Node<Value>>();
- Stack<Node<Value>> stack = new Stack<Node<Value>>();
- stack.push(nodes.get(0));
-
- while (! stack.empty()) {
- Node<Value> node = stack.pop();
- sortedNodes.add(node);
- for (Arc<Value> arc: outgraph.get(node)) {
- outgraph.get(node).remove(arc);
- ingraph.get(arc.getHead()).remove(arc);
-
- if (ingraph.get(arc.getHead()).size() == 0)
- sortedNodes.add(arc.getHead());
- }
- }
-
- int id = 0;
- for (Node<Value> node : sortedNodes)
- node.setID(id++);
-
- this.nodes = sortedNodes;
- }
-
- /**
- * Constructs a lattice from a given string representation.
++ * Constructs a lattice from a given string representation.
+ *
- * @param data String representation of a lattice.
- * @return A lattice that corresponds to the given string.
++ * @param data String representation of a lattice.
++ * @return A lattice that corresponds to the given string.
+ */
+ public static Lattice<String> createFromString(String data) {
+
+ Map<Integer,Node<String>> nodes = new HashMap<Integer,Node<String>>();
+
+ Pattern nodePattern = Pattern.compile("(.+?)\\((\\(.+?\\),)\\)(.*)");
+ Pattern arcPattern = Pattern.compile("\\('(.+?)',(\\d+.\\d+),(\\d+)\\),(.*)");
+
+ Matcher nodeMatcher = nodePattern.matcher(data);
+
+ int nodeID = -1;
+
+ while (nodeMatcher.matches()) {
+
+ String nodeData = nodeMatcher.group(2);
+ String remainingData = nodeMatcher.group(3);
+
+ nodeID++;
+
+ Node<String> currentNode;
+ if (nodes.containsKey(nodeID)) {
+ currentNode = nodes.get(nodeID);
+ } else {
+ currentNode = new Node<String>(nodeID);
+ nodes.put(nodeID, currentNode);
+ }
+
+ LOG.debug("Node : {}", nodeID);
+
+ Matcher arcMatcher = arcPattern.matcher(nodeData);
+
+ while (arcMatcher.matches()) {
+ String arcLabel = arcMatcher.group(1);
+ double arcWeight = Double.valueOf(arcMatcher.group(2));
+ int destinationNodeID = nodeID + Integer.valueOf(arcMatcher.group(3));
+
+ Node<String> destinationNode;
+ if (nodes.containsKey(destinationNodeID)) {
+ destinationNode = nodes.get(destinationNodeID);
+ } else {
+ destinationNode = new Node<String>(destinationNodeID);
+ nodes.put(destinationNodeID, destinationNode);
+ }
+
+ String remainingArcs = arcMatcher.group(4);
+
+ LOG.debug("\t {} {} {}", arcLabel, arcWeight, destinationNodeID);
+
+ currentNode.addArc(destinationNode, (float) arcWeight, arcLabel);
+
+ arcMatcher = arcPattern.matcher(remainingArcs);
+ }
+
+ nodeMatcher = nodePattern.matcher(remainingData);
+ }
+
+ List<Node<String>> nodeList = new ArrayList<Node<String>>(nodes.values());
+ Collections.sort(nodeList, new NodeIdentifierComparator());
+
+ LOG.debug("Nodelist={}", nodeList);
+
+ return new Lattice<String>(nodeList, new JoshuaConfiguration());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/metrics/CHRF.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/metrics/CHRF.java
index d67f6e0,0000000..dcf606a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/metrics/CHRF.java
+++ b/joshua-core/src/main/java/org/apache/joshua/metrics/CHRF.java
@@@ -1,308 -1,0 +1,303 @@@
+/*
+ * Copyright 2016 The Apache Software Foundation.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.joshua.metrics;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.logging.Logger;
+
+
+/**
+ *
+ * An implementation of the chrF evaluation metric for tuning.
+ * It is based on the original code by Maja Popovic [1] with the following main modifications:
+ * - Adapted to extend Joshua's EvaluationMetric class
+ * - Use of a length penalty to prevent chrF to prefer too long (with beta %gt; 1) or too short (with beta < 1) translations
+ * - Use of hash tables for efficient n-gram matching
- *
++ *
+ * The metric has 2 parameters:
+ * - Beta. It assigns beta times more weight to recall than to precision. By default 1.
+ * Although for evaluation the best correlation was found with beta=3, we've found the
+ * best results for tuning so far with beta=1
+ * - Max-ngram. Maximum n-gram length (characters). By default 6.
- *
++ *
+ * If you use this metric in your research please cite [2].
- *
++ *
+ * [1] Maja Popovic. 2015. chrF: character n-gram F-score for automatic MT evaluation.
+ * In Proceedings of the Tenth Workshop on Statistical Machine Translation. Lisbon, Portugal, pages 392\u2013395.
- * [2] V�ctor S�nchez Cartagena and Antonio Toral. 2016.
++ * [2] V�ctor S�nchez Cartagena and Antonio Toral. 2016.
+ * Abu-MaTran at WMT 2016 Translation Task: Deep Learning, Morphological Segmentation and Tuning on Character Sequences.
+ * In Proceedings of the First Conference on Machine Translation (WMT16). Berlin, Germany.
+
+ * @author Antonio Toral
+ */
+public class CHRF extends EvaluationMetric {
+ private static final Logger logger = Logger.getLogger(CHRF.class.getName());
+
+ protected double beta = 1;
+ protected double factor;
+ protected int maxGramLength = 6; // The maximum n-gram we care about
+ //private double[] nGramWeights; //TODO to weight them differently
-
++
+ //private String metricName;
+ //private boolean toBeMinimized;
+ //private int suffStatsCount;
-
++
+
+ public CHRF()
+ {
+ this(1, 6);
+ }
-
++
+ public CHRF(String[] CHRF_options)
+ {
+ //
+ //
+ // process the Metric_options array
+ //
+ //
+ this(Double.parseDouble(CHRF_options[0]), Integer.parseInt(CHRF_options[1]));
+ }
-
- public CHRF(double bt, int mxGrmLn){
++
++ public CHRF(double bt, int mxGrmLn){
+ if (bt > 0) {
+ beta = bt;
+ } else {
+ logger.severe("Beta must be positive");
+ System.exit(1);
+ }
-
++
+ if (mxGrmLn >= 1) {
+ maxGramLength = mxGrmLn;
+ } else {
+ logger.severe("Maximum gram length must be positive");
+ System.exit(1);
+ }
-
++
+ initialize(); // set the data members of the metric
+ }
+
++ @Override
+ protected void initialize()
+ {
+ metricName = "CHRF";
+ toBeMinimized = false;
+ suffStatsCount = 4 * maxGramLength;
- factor = Math.pow(beta, 2);
++ factor = Math.pow(beta, 2);
+ }
-
++
++ @Override
+ public double bestPossibleScore() { return 100.0; }
-
++
++ @Override
+ public double worstPossibleScore() { return 0.0; }
+
+ protected String separateCharacters(String s)
+ {
- String s_chars = "";
++ String s_chars = "";
+ //alternative implementation (the one below seems more robust)
+ /*for (int i = 0; i < s.length(); i++) {
+ if (s.charAt(i) == ' ') continue;
- s_chars += s.charAt(i) + " ";
++ s_chars += s.charAt(i) + " ";
+ }
+ System.out.println("CHRF separate chars1: " + s_chars);*/
+
+ String[] words = s.split("\\s+");
+ for (String w: words) {
+ for (int i = 0; i<w.length(); i++)
+ s_chars += w.charAt(i);
+ }
+
+ //System.out.println("CHRF separate chars: " + s_chars);
+ return s_chars;
+ }
+
-
++
+ protected HashMap<String, Integer>[] getGrams(String s)
+ {
+ HashMap<String, Integer>[] grams = new HashMap[1 + maxGramLength];
+ grams[0] = null;
+ for (int n = 1; n <= maxGramLength; ++n) {
+ grams[n] = new HashMap<String, Integer>();
+ }
+
-
++
+ for (int n=1; n<=maxGramLength; n++){
+ String gram = "";
+ for (int i = 0; i < s.length() - n + 1; i++){
+ gram = s.substring(i, i+n);
+ if(grams[n].containsKey(gram)){
+ int old_count = grams[n].get(gram);
+ grams[n].put(gram, old_count+1);
+ } else {
+ grams[n].put(gram, 1);
+ }
+ }
-
++
+ }
+
+ /* debugging
+ String key, value;
+ for (int n=1; n<=maxGramLength; n++){
+ System.out.println("Grams of order " + n);
+ for (String gram: grams[n].keySet()){
+ key = gram.toString();
+ value = grams[n].get(gram).toString();
- System.out.println(key + " " + value);
++ System.out.println(key + " " + value);
+ }
+ }*/
-
++
+ return grams;
+ }
+
-
++
+ protected int[] candRefErrors(HashMap<String, Integer> ref, HashMap<String, Integer> cand)
+ {
+ int[] to_return = {0,0};
+ String gram;
- int cand_grams = 0, ref_grams = 0;
++ int cand_grams = 0;
+ int candGramCount = 0, refGramCount = 0;
+ int errors = 0;
- double result = 0;
- String not_found = "";
-
-
++
+ Iterator<String> it = (cand.keySet()).iterator();
+
+ while (it.hasNext()) {
+ gram = it.next();
+ candGramCount = cand.get(gram);
+ cand_grams += candGramCount;
+ if (ref.containsKey(gram)) {
+ refGramCount = ref.get(gram);
- ref_grams += refGramCount;
+ if(candGramCount>refGramCount){
+ int error_here = candGramCount - refGramCount;
+ errors += error_here;
- not_found += gram + " (" + error_here + " times) ";
+ }
+ } else {
+ refGramCount = 0;
+ errors += candGramCount;
- not_found += gram + " ";
- }
++ }
+ }
-
++
+ //System.out.println(" Ngrams not found: " + not_found);
-
++
+ to_return[0] = cand_grams;
+ to_return[1] = errors;
-
++
+ return to_return;
+ }
-
++
++ @Override
+ public int[] suffStats(String cand_str, int i) //throws Exception
+ {
+ int[] stats = new int[suffStatsCount];
+
- double[] precisions = new double[maxGramLength+1];
- double[] recalls = new double[maxGramLength+1];
-
+ //TODO check unicode chars correctly split
+ String cand_char = separateCharacters(cand_str);
+ String ref_char = separateCharacters(refSentences[i][0]);
-
++
+ HashMap<String, Integer>[] grams_cand = getGrams(cand_char);
+ HashMap<String, Integer>[] grams_ref = getGrams(ref_char);
-
++
+ for (int n = 1; n <= maxGramLength; ++n) {
+ //System.out.println("Calculating precision...");
+ int[] precision_vals = candRefErrors(grams_ref[n], grams_cand[n]);
+ //System.out.println(" length: " + precision_vals[0] + ", errors: " + precision_vals[1]);
+ //System.out.println("Calculating recall...");
+ int[] recall_vals = candRefErrors(grams_cand[n], grams_ref[n]);
+ //System.out.println(" length: " + recall_vals[0] + ", errors: " + recall_vals[1]);
-
++
+ stats[4*(n-1)] = precision_vals[0]; //cand_grams
+ stats[4*(n-1)+1] = precision_vals[1]; //errors (precision)
+ stats[4*(n-1)+2] = recall_vals[0]; //ref_grams
+ stats[4*(n-1)+3] = recall_vals[1]; //errors (recall)
+ }
+
+ return stats;
+ }
+
+
++ @Override
+ public double score(int[] stats)
+ {
+ int precision_ngrams, recall_ngrams, precision_errors, recall_errors;
+ double[] precisions = new double[maxGramLength+1];
+ double[] recalls = new double[maxGramLength+1];
+ double[] fs = new double[maxGramLength+1];
+ //double[] scs = new double[maxGramLength+1];
- double totalPrecision = 0, totalRecall = 0, totalF = 0, totalSC = 0;
++ double totalF = 0, totalSC = 0;
+ double lp = 1;
-
++
+ if (stats.length != suffStatsCount) {
+ System.out.println("Mismatch between stats.length and suffStatsCount (" + stats.length + " vs. " + suffStatsCount + ") in NewMetric.score(int[])");
+ System.exit(1);
+ }
+
+ for (int n = 1; n <= maxGramLength; n++) {
+ precision_ngrams = stats[4 * (n - 1)];
+ precision_errors = stats[4 * (n - 1) + 1];
+ recall_ngrams = stats[4 * (n - 1) + 2];
+ recall_errors = stats[4 * (n - 1) + 3];
+
+ if (precision_ngrams != 0)
+ precisions[n] = 100 - 100*precision_errors/ (double)precision_ngrams;
+ else precisions[n] = 0;
-
++
+ if (recall_ngrams != 0)
+ recalls[n] = 100 - 100*recall_errors/ (double)recall_ngrams;
+ else
+ recalls[n] = 0;
-
++
+ if(precisions[n] != 0 || recalls[n] != 0)
+ fs[n] = (1+factor) * recalls[n] * precisions[n] / (factor * precisions[n] + recalls[n]);
+ else
+ fs[n] = 0;
-
++
+ //System.out.println("Precision (n=" + n + "): " + precisions[n]);
+ //System.out.println("Recall (n=" + n + "): " + recalls[n]);
+ //System.out.println("F (n=" + n + "): " + fs[n]);
+
- totalPrecision += (1/(double)maxGramLength) * precisions[n];
- totalRecall += (1/(double)maxGramLength) * recalls[n];
+ totalF += (1/(double)maxGramLength) * fs[n];
+ }
+
+ //length penalty
- if (beta>1){ //penalise long translations
++ if (beta>1){ //penalise long translations
+ lp = Math.min(1, stats[2]/(double)stats[0]);
+ } else if (beta < 1){ //penalise short translations
+ lp = Math.min(1, stats[0]/(double)stats[2]);
+ }
+ totalSC = totalF*lp;
-
++
+ //System.out.println("Precision (total): " + totalPrecision);
+ //System.out.println("Recall (total):" + totalRecall);
+ //System.out.println("F (total): " + totalF);
-
++
+ return totalSC;
+ }
+
+
++ @Override
+ public void printDetailedScore_fromStats(int[] stats, boolean oneLiner)
+ {
+ System.out.println(metricName + " = " + score(stats));
+
+ //
+ //
+ // optional (for debugging purposes)
+ //
+ //
+ }
+
+}
+
[14/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
index e88f00a,0000000..5c75188
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/Translation.java
@@@ -1,239 -1,0 +1,238 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
- import static java.util.Arrays.asList;
+import static org.apache.joshua.decoder.StructuredTranslationFactory.fromViterbiDerivation;
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiFeatures;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiWordAlignments;
+import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.StringWriter;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.lm.StateMinimizingLanguageModel;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor;
+import org.apache.joshua.decoder.io.DeNormalize;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class represents translated input objects (sentences or lattices). It is aware of the source
+ * sentence and id and contains the decoded hypergraph. Translation objects are returned by
+ * DecoderTask instances to the InputHandler, where they are assembled in order for output.
- *
++ *
+ * @author Matt Post post@cs.jhu.edu
+ * @author Felix Hieber fhieber@amazon.com
+ */
+
+public class Translation {
+ private static final Logger LOG = LoggerFactory.getLogger(Translation.class);
+ private final Sentence source;
+
+ /**
+ * This stores the output of the translation so we don't have to hold onto the hypergraph while we
+ * wait for the outputs to be assembled.
+ */
+ private String output = null;
+
+ /**
+ * Stores the list of StructuredTranslations.
+ * If joshuaConfig.topN == 0, will only contain the Viterbi translation.
+ * Else it will use KBestExtractor to populate this list.
+ */
+ private List<StructuredTranslation> structuredTranslations = null;
-
- public Translation(Sentence source, HyperGraph hypergraph,
++
++ public Translation(Sentence source, HyperGraph hypergraph,
+ List<FeatureFunction> featureFunctions, JoshuaConfiguration joshuaConfiguration) {
+ this.source = source;
-
++
+ /**
+ * Structured output from Joshua provides a way to programmatically access translation results
+ * from downstream applications, instead of writing results as strings to an output buffer.
+ */
+ if (joshuaConfiguration.use_structured_output) {
-
++
+ if (joshuaConfiguration.topN == 0) {
+ /*
+ * Obtain Viterbi StructuredTranslation
+ */
+ StructuredTranslation translation = fromViterbiDerivation(source, hypergraph, featureFunctions);
+ this.output = translation.getTranslationString();
+ structuredTranslations = Collections.singletonList(translation);
-
++
+ } else {
+ /*
+ * Get K-Best list of StructuredTranslations
+ */
+ final KBestExtractor kBestExtractor = new KBestExtractor(source, featureFunctions, Decoder.weights, false, joshuaConfiguration);
+ structuredTranslations = kBestExtractor.KbestExtractOnHG(hypergraph, joshuaConfiguration.topN);
+ if (structuredTranslations.isEmpty()) {
+ structuredTranslations = Collections
+ .singletonList(StructuredTranslationFactory.fromEmptyOutput(source));
+ this.output = "";
+ } else {
+ this.output = structuredTranslations.get(0).getTranslationString();
+ }
+ // TODO: We omit the BLEU rescoring for now since it is not clear whether it works at all and what the desired output is below.
+ }
+
+ } else {
+
+ StringWriter sw = new StringWriter();
+ BufferedWriter out = new BufferedWriter(sw);
+
+ try {
-
++
+ if (hypergraph != null) {
-
++
+ long startTime = System.currentTimeMillis();
+
+ if (joshuaConfiguration.topN == 0) {
+
+ /* construct Viterbi output */
+ final String best = getViterbiString(hypergraph);
+
+ LOG.info("Translation {}: {} {}", source.id(), hypergraph.goalNode.getScore(), best);
+
+ /*
+ * Setting topN to 0 turns off k-best extraction, in which case we need to parse through
+ * the output-string, with the understanding that we can only substitute variables for the
+ * output string, sentence number, and model score.
+ */
+ String translation = joshuaConfiguration.outputFormat
+ .replace("%s", removeSentenceMarkers(best))
+ .replace("%S", DeNormalize.processSingleLine(best))
+ .replace("%c", String.format("%.3f", hypergraph.goalNode.getScore()))
+ .replace("%i", String.format("%d", source.id()));
+
+ if (joshuaConfiguration.outputFormat.contains("%a")) {
+ translation = translation.replace("%a", getViterbiWordAlignments(hypergraph));
+ }
+
+ if (joshuaConfiguration.outputFormat.contains("%f")) {
+ final FeatureVector features = getViterbiFeatures(hypergraph, featureFunctions, source);
+ translation = translation.replace("%f", features.textFormat());
+ }
+
+ out.write(translation);
+ out.newLine();
+
+ } else {
+
+ final KBestExtractor kBestExtractor = new KBestExtractor(
+ source, featureFunctions, Decoder.weights, false, joshuaConfiguration);
+ kBestExtractor.lazyKBestExtractOnHG(hypergraph, joshuaConfiguration.topN, out);
+
+ if (joshuaConfiguration.rescoreForest) {
+ final int bleuFeatureHash = hashFeature("BLEU");
+ Decoder.weights.add(bleuFeatureHash, joshuaConfiguration.rescoreForestWeight);
+ kBestExtractor.lazyKBestExtractOnHG(hypergraph, joshuaConfiguration.topN, out);
+
+ Decoder.weights.add(bleuFeatureHash, -joshuaConfiguration.rescoreForestWeight);
+ kBestExtractor.lazyKBestExtractOnHG(hypergraph, joshuaConfiguration.topN, out);
+ }
+ }
+
- float seconds = (float) (System.currentTimeMillis() - startTime) / 1000.0f;
++ float seconds = (System.currentTimeMillis() - startTime) / 1000.0f;
+ LOG.info("Input {}: {}-best extraction took {} seconds", id(),
+ joshuaConfiguration.topN, seconds);
+
+ } else {
-
++
+ // Failed translations and blank lines get empty formatted outputs
+ out.write(getFailedTranslationOutput(source, joshuaConfiguration));
+ out.newLine();
-
++
+ }
+
+ out.flush();
-
++
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ this.output = sw.toString();
+
+ }
-
++
+ // remove state from StateMinimizingLanguageModel instances in features.
+ destroyKenLMStates(featureFunctions);
+
+ }
+
+ public Sentence getSourceSentence() {
+ return this.source;
+ }
+
+ public int id() {
+ return source.id();
+ }
+
+ @Override
+ public String toString() {
+ return output;
+ }
-
++
+ private String getFailedTranslationOutput(final Sentence source, final JoshuaConfiguration joshuaConfiguration) {
+ return joshuaConfiguration.outputFormat
+ .replace("%s", source.source())
+ .replace("%e", "")
+ .replace("%S", "")
+ .replace("%t", "()")
+ .replace("%i", Integer.toString(source.id()))
+ .replace("%f", "")
+ .replace("%c", "0.000");
+ }
-
++
+ /**
+ * Returns the StructuredTranslations
+ * if JoshuaConfiguration.use_structured_output == True.
+ * @throws RuntimeException if JoshuaConfiguration.use_structured_output == False.
+ * @return List of StructuredTranslations.
+ */
+ public List<StructuredTranslation> getStructuredTranslations() {
+ if (structuredTranslations == null) {
+ throw new RuntimeException(
+ "No StructuredTranslation objects created. You should set JoshuaConfigration.use_structured_output = true");
+ }
+ return structuredTranslations;
+ }
-
++
+ /**
+ * KenLM hack. If using KenLMFF, we need to tell KenLM to delete the pool used to create chart
+ * objects for this sentence.
+ */
+ private void destroyKenLMStates(final List<FeatureFunction> featureFunctions) {
+ for (FeatureFunction feature : featureFunctions) {
+ if (feature instanceof StateMinimizingLanguageModel) {
+ ((StateMinimizingLanguageModel) feature).destroyPool(getSourceSentence().id());
+ break;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
index 8b5c81a,0000000..0e5139a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
@@@ -1,474 -1,0 +1,438 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.chart_parser;
+
+import java.util.ArrayList;
- import java.util.HashMap;
+import java.util.List;
- import java.util.Map;
+
- import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.tm.Grammar;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleCollection;
+import org.apache.joshua.decoder.ff.tm.Trie;
+import org.apache.joshua.decoder.segment_file.Token;
+import org.apache.joshua.lattice.Arc;
+import org.apache.joshua.lattice.Lattice;
+import org.apache.joshua.lattice.Node;
+import org.apache.joshua.util.ChartSpan;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The DotChart handles Earley-style implicit binarization of translation rules.
- *
++ *
+ * The {@link DotNode} object represents the (possibly partial) application of a synchronous rule.
+ * The implicit binarization is maintained with a pointer to the {@link Trie} node in the grammar,
+ * for easy retrieval of the next symbol to be matched. At every span (i,j) of the input sentence,
+ * every incomplete DotNode is examined to see whether it (a) needs a terminal and matches against
+ * the final terminal of the span or (b) needs a nonterminal and matches against a completed
+ * nonterminal in the main chart at some split point (k,j).
- *
++ *
+ * Once a rule is completed, it is entered into the {@link DotChart}. {@link DotCell} objects are
+ * used to group completed DotNodes over a span.
- *
++ *
+ * There is a separate DotChart for every grammar.
- *
++ *
+ * @author Zhifei Li, <zh...@gmail.com>
+ * @author Matt Post <po...@cs.jhu.edu>
+ * @author Kristy Hollingshead Seitz
+ */
+class DotChart {
+
+ // ===============================================================
+ // Static fields
+ // ===============================================================
+
+ private static final Logger LOG = LoggerFactory.getLogger(DotChart.class);
+
+
+ // ===============================================================
+ // Package-protected instance fields
+ // ===============================================================
+ /**
+ * Two-dimensional chart of cells. Some cells might be null. This could definitely be represented
+ * more efficiently, since only the upper half of this triangle is every used.
+ */
+ private final ChartSpan<DotCell> dotcells;
+
+ public DotCell getDotCell(int i, int j) {
+ return dotcells.get(i, j);
+ }
+
+ // ===============================================================
+ // Private instance fields (maybe could be protected instead)
+ // ===============================================================
+
+ /**
+ * CKY+ style parse chart in which completed span entries are stored.
+ */
+ private final Chart dotChart;
+
+ /**
+ * Translation grammar which contains the translation rules.
+ */
+ private final Grammar pGrammar;
+
+ /* Length of input sentence. */
+ private final int sentLen;
+
+ /* Represents the input sentence being translated. */
+ private final Lattice<Token> input;
+
+ // ===============================================================
+ // Constructors
+ // ===============================================================
+
+ // TODO: Maybe this should be a non-static inner class of Chart. That would give us implicit
+ // access to all the arguments of this constructor. Though we would need to take an argument, i,
+ // to know which Chart.this.grammars[i] to use.
+
+ /**
+ * Constructs a new dot chart from a specified input lattice, a translation grammar, and a parse
+ * chart.
- *
++ *
+ * @param input A lattice which represents an input sentence.
+ * @param grammar A translation grammar.
+ * @param chart A CKY+ style chart in which completed span entries are stored.
+ */
+ public DotChart(Lattice<Token> input, Grammar grammar, Chart chart) {
+
+ this.dotChart = chart;
+ this.pGrammar = grammar;
+ this.input = input;
+ this.sentLen = input.size();
+ this.dotcells = new ChartSpan<>(sentLen, null);
+
+ seed();
+ }
+
+ /**
+ * Add initial dot items: dot-items pointer to the root of the grammar trie.
+ */
+ void seed() {
+ for (int j = 0; j <= sentLen - 1; j++) {
+ if (pGrammar.hasRuleForSpan(j, j, input.distance(j, j))) {
+ if (null == pGrammar.getTrieRoot()) {
+ throw new RuntimeException("trie root is null");
+ }
+ addDotItem(pGrammar.getTrieRoot(), j, j, null, null, new SourcePath());
+ }
+ }
+ }
+
+ /**
+ * This function computes all possible expansions of all rules over the provided span (i,j). By
+ * expansions, we mean the moving of the dot forward (from left to right) over a nonterminal or
+ * terminal symbol on the rule's source side.
- *
++ *
+ * There are two kinds of expansions:
- *
++ *
+ * <ol>
+ * <li>Expansion over a nonterminal symbol. For this kind of expansion, a rule has a dot
+ * immediately prior to a source-side nonterminal. The main Chart is consulted to see whether
+ * there exists a completed nonterminal with the same label. If so, the dot is advanced.
- *
++ *
+ * Discovering nonterminal expansions is a matter of enumerating all split points k such that i <
+ * k and k < j. The nonterminal symbol must exist in the main Chart over (k,j).
- *
++ *
+ * <li>Expansion over a terminal symbol. In this case, expansion is a simple matter of determing
+ * whether the input symbol at position j (the end of the span) matches the next symbol in the
+ * rule. This is equivalent to choosing a split point k = j - 1 and looking for terminal symbols
+ * over (k,j). Note that phrases in the input rule are handled one-by-one as we consider longer
+ * spans.
+ * </ol>
+ */
+ void expandDotCell(int i, int j) {
+ if (LOG.isDebugEnabled())
+ LOG.debug("Expanding dot cell ({}, {})", i, j);
+
+ /*
+ * (1) If the dot is just to the left of a non-terminal variable, we look for theorems or axioms
+ * in the Chart that may apply and extend the dot position. We look for existing axioms over all
+ * spans (k,j), i < k < j.
+ */
+ for (int k = i + 1; k < j; k++) {
+ extendDotItemsWithProvedItems(i, k, j, false);
+ }
+
+ /*
+ * (2) If the the dot-item is looking for a source-side terminal symbol, we simply match against
+ * the input sentence and advance the dot.
+ */
+ Node<Token> node = input.getNode(j - 1);
+ for (Arc<Token> arc : node.getOutgoingArcs()) {
+
+ int last_word = arc.getLabel().getWord();
+ int arc_len = arc.getHead().getNumber() - arc.getTail().getNumber();
+
+ // int last_word=foreign_sent[j-1]; // input.getNode(j-1).getNumber(); //
+
+ if (null != dotcells.get(i, j - 1)) {
+ // dotitem in dot_bins[i][k]: looking for an item in the right to the dot
+
+
+ for (DotNode dotNode : dotcells.get(i, j - 1).getDotNodes()) {
+
+ // String arcWord = Vocabulary.word(last_word);
+ // Assert.assertFalse(arcWord.endsWith("]"));
+ // Assert.assertFalse(arcWord.startsWith("["));
+ // logger.info("DotChart.expandDotCell: " + arcWord);
+
+ // List<Trie> child_tnodes = ruleMatcher.produceMatchingChildTNodesTerminalevel(dotNode,
+ // last_word);
+
- List<Trie> child_tnodes = null;
-
+ Trie child_node = dotNode.trieNode.match(last_word);
+ if (null != child_node) {
+ addDotItem(child_node, i, j - 1 + arc_len, dotNode.antSuperNodes, null,
+ dotNode.srcPath.extend(arc));
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * note: (i,j) is a non-terminal, this cannot be a cn-side terminal, which have been handled in
+ * case2 of dotchart.expand_cell add dotitems that start with the complete super-items in
+ * cell(i,j)
+ */
+ void startDotItems(int i, int j) {
+ extendDotItemsWithProvedItems(i, i, j, true);
+ }
+
+ // ===============================================================
+ // Private methods
+ // ===============================================================
+
+ /**
+ * Attempt to combine an item in the dot chart with an item in the main chart to create a new item
+ * in the dot chart. The DotChart item is a {@link DotNode} begun at position i with the dot
+ * currently at position k, that is, a partially-applied rule.
- *
++ *
+ * In other words, this method looks for (proved) theorems or axioms in the completed chart that
+ * may apply and extend the dot position.
- *
++ *
+ * @param i Start index of a dot chart item
+ * @param k End index of a dot chart item; start index of a completed chart item
+ * @param j End index of a completed chart item
+ * @param skipUnary if true, don't extend unary rules
+ */
+ private void extendDotItemsWithProvedItems(int i, int k, int j, boolean skipUnary) {
+ if (this.dotcells.get(i, k) == null || this.dotChart.getCell(k, j) == null) {
+ return;
+ }
+
+ // complete super-items (items over the same span with different LHSs)
+ List<SuperNode> superNodes = new ArrayList<>(this.dotChart.getCell(k, j).getSortedSuperItems().values());
+
+ /* For every partially complete item over (i,k) */
+ for (DotNode dotNode : dotcells.get(i, k).dotNodes) {
+ /* For every completed nonterminal in the main chart */
+ for (SuperNode superNode : superNodes) {
+
+ // String arcWord = Vocabulary.word(superNode.lhs);
+ // logger.info("DotChart.extendDotItemsWithProvedItems: " + arcWord);
+ // Assert.assertTrue(arcWord.endsWith("]"));
+ // Assert.assertTrue(arcWord.startsWith("["));
+
+ /*
+ * Regular Expression matching allows for a regular-expression style rules in the grammar,
+ * which allows for a very primitive treatment of morphology. This is an advanced,
+ * undocumented feature that introduces a complexity, in that the next "word" in the grammar
+ * rule might match more than one outgoing arc in the grammar trie.
+ */
+ Trie child_node = dotNode.getTrieNode().match(superNode.lhs);
+ if (child_node != null) {
+ if ((!skipUnary) || (child_node.hasExtensions())) {
+ addDotItem(child_node, i, j, dotNode.getAntSuperNodes(), superNode, dotNode
+ .getSourcePath().extendNonTerminal());
+ }
+ }
+ }
+ }
+ }
+
- /*
- * We introduced the ability to have regular expressions in rules for matching against terminals.
- * For example, you could have the rule
- *
- * <pre> [X] ||| l?s herman?s ||| siblings </pre>
- *
- * When this is enabled for a grammar, we need to test against *all* (positive) outgoing arcs of
- * the grammar trie node to see if any of them match, and then return the whole set. This is quite
- * expensive, which is why you should only enable regular expressions for small grammars.
- */
-
- private ArrayList<Trie> matchAll(DotNode dotNode, int wordID) {
- ArrayList<Trie> trieList = new ArrayList<>();
- HashMap<Integer, ? extends Trie> childrenTbl = dotNode.trieNode.getChildren();
-
- if (childrenTbl != null && wordID >= 0) {
- // get all the extensions, map to string, check for *, build regexp
- for (Map.Entry<Integer, ? extends Trie> entry : childrenTbl.entrySet()) {
- Integer arcID = entry.getKey();
- if (arcID == wordID) {
- trieList.add(entry.getValue());
- } else {
- String arcWord = Vocabulary.word(arcID);
- if (Vocabulary.word(wordID).matches(arcWord)) {
- trieList.add(entry.getValue());
- }
- }
- }
- }
- return trieList;
- }
-
-
+ /**
+ * Creates a {@link DotNode} and adds it into the {@link DotChart} at the correct place. These
- * are (possibly incomplete) rule applications.
- *
++ * are (possibly incomplete) rule applications.
++ *
+ * @param tnode the trie node pointing to the location ("dot") in the grammar trie
+ * @param i
+ * @param j
+ * @param antSuperNodesIn the supernodes representing the rule's tail nodes
+ * @param curSuperNode the lefthand side of the rule being created
+ * @param srcPath the path taken through the input lattice
+ */
+ private void addDotItem(Trie tnode, int i, int j, ArrayList<SuperNode> antSuperNodesIn,
+ SuperNode curSuperNode, SourcePath srcPath) {
+ ArrayList<SuperNode> antSuperNodes = new ArrayList<>();
+ if (antSuperNodesIn != null) {
+ antSuperNodes.addAll(antSuperNodesIn);
+ }
+ if (curSuperNode != null) {
+ antSuperNodes.add(curSuperNode);
+ }
+
+ DotNode item = new DotNode(i, j, tnode, antSuperNodes, srcPath);
+ if (dotcells.get(i, j) == null) {
+ dotcells.set(i, j, new DotCell());
+ }
+ dotcells.get(i, j).addDotNode(item);
+ dotChart.nDotitemAdded++;
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Add a dotitem in cell ({}, {}), n_dotitem={}, {}", i, j,
+ dotChart.nDotitemAdded, srcPath);
+
+ RuleCollection rules = tnode.getRuleCollection();
+ if (rules != null) {
+ for (Rule r : rules.getRules()) {
+ // System.out.println("rule: "+r.toString());
+ LOG.debug("{}", r);
+ }
+ }
+ }
+ }
+
+ // ===============================================================
+ // Package-protected classes
+ // ===============================================================
+
+ /**
+ * A DotCell groups together DotNodes that have been applied over a particular span. A DotNode, in
+ * turn, is a partially-applied grammar rule, represented as a pointer into the grammar trie
+ * structure.
+ */
+ static class DotCell {
+
+ // Package-protected fields
+ private final List<DotNode> dotNodes = new ArrayList<>();
+
+ public List<DotNode> getDotNodes() {
+ return dotNodes;
+ }
+
+ private void addDotNode(DotNode dt) {
+ /*
+ * if(l_dot_items==null) l_dot_items= new ArrayList<DotItem>();
+ */
+ dotNodes.add(dt);
+ }
+ }
+
+ /**
+ * A DotNode represents the partial application of a rule rooted to a particular span (i,j). It
+ * maintains a pointer to the trie node in the grammar for efficient mapping.
+ */
+ static class DotNode {
+
+ private final int i;
+ private final int j;
+ private Trie trieNode = null;
-
++
+ /* A list of grounded (over a span) nonterminals that have been crossed in traversing the rule */
+ private ArrayList<SuperNode> antSuperNodes = null;
-
++
+ /* The source lattice cost of applying the rule */
+ private final SourcePath srcPath;
+
+ @Override
+ public String toString() {
+ int size = 0;
+ if (trieNode != null && trieNode.getRuleCollection() != null)
+ size = trieNode.getRuleCollection().getRules().size();
+ return String.format("DOTNODE i=%d j=%d #rules=%d #tails=%d", i, j, size, antSuperNodes.size());
+ }
-
++
+ /**
+ * Initialize a dot node with the span, grammar trie node, list of supernode tail pointers, and
+ * the lattice sourcepath.
- *
++ *
+ * @param i
+ * @param j
+ * @param trieNode
+ * @param antSuperNodes
+ * @param srcPath
+ */
+ public DotNode(int i, int j, Trie trieNode, ArrayList<SuperNode> antSuperNodes, SourcePath srcPath) {
+ this.i = i;
+ this.j = j;
+ this.trieNode = trieNode;
+ this.antSuperNodes = antSuperNodes;
+ this.srcPath = srcPath;
+ }
+
++ @Override
+ public boolean equals(Object obj) {
+ if (obj == null)
+ return false;
+ if (!this.getClass().equals(obj.getClass()))
+ return false;
+ DotNode state = (DotNode) obj;
+
+ /*
+ * Technically, we should be comparing the span inforamtion as well, but that would require us
+ * to store it, increasing memory requirements, and we should be able to guarantee that we
+ * won't be comparing DotNodes across spans.
+ */
+ // if (this.i != state.i || this.j != state.j)
+ // return false;
+
+ return this.trieNode == state.trieNode;
+
+ }
+
+ /**
+ * Technically the hash should include the span (i,j), but since DotNodes are grouped by span,
+ * this isn't necessary, and we gain something by not having to store the span.
+ */
++ @Override
+ public int hashCode() {
+ return this.trieNode.hashCode();
+ }
+
+ // convenience function
+ public boolean hasRules() {
+ return getTrieNode().getRuleCollection() != null && getTrieNode().getRuleCollection().getRules().size() != 0;
+ }
-
++
+ public RuleCollection getRuleCollection() {
+ return getTrieNode().getRuleCollection();
+ }
+
+ public Trie getTrieNode() {
+ return trieNode;
+ }
+
+ public SourcePath getSourcePath() {
+ return srcPath;
+ }
+
+ public ArrayList<SuperNode> getAntSuperNodes() {
+ return antSuperNodes;
+ }
+
+ public int begin() {
+ return i;
+ }
-
++
+ public int end() {
+ return j;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
index 9338b0d,0000000..d9b894c
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
@@@ -1,215 -1,0 +1,214 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff;
+
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.decoder.chart_parser.SourcePath;
+import org.apache.joshua.decoder.ff.state_maintenance.DPState;
+import org.apache.joshua.decoder.ff.state_maintenance.NgramDPState;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.io.LineReader;
+
+/***
+ * The RuleBigram feature is an indicator feature that counts target word bigrams that are created when
+ * a rule is applied. It accepts three parameters:
+ *
+ * -vocab /path/to/vocab
+ *
+ * The path to a vocabulary, where each line is of the format ID WORD COUNT.
+ *
+ * -threshold N
+ *
+ * Mask to UNK all words whose COUNT is less than N.
+ *
+ * -top-n N
+ *
+ * Only use the top N words.
+ */
+
+public class TargetBigram extends StatefulFF {
+
+ private HashSet<String> vocab = null;
+ private int maxTerms = 1000000;
+ private int threshold = 0;
+
+ public TargetBigram(FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "TargetBigram", args, config);
+
+ if (parsedArgs.containsKey("threshold"))
+ threshold = Integer.parseInt(parsedArgs.get("threshold"));
+
+ if (parsedArgs.containsKey("top-n"))
+ maxTerms = Integer.parseInt(parsedArgs.get("top-n"));
+
+ if (parsedArgs.containsKey("vocab")) {
+ loadVocab(parsedArgs.get("vocab"));
+ }
+ }
+
+ /**
+ * Load vocabulary items passing the 'threshold' and 'top-n' filters.
+ *
+ * @param filename
+ */
+ private void loadVocab(String filename) {
+ this.vocab = new HashSet<>();
+ this.vocab.add("<s>");
+ this.vocab.add("</s>");
- try {
- LineReader lineReader = new LineReader(filename);
++ try(LineReader lineReader = new LineReader(filename);) {
+ for (String line: lineReader) {
+ if (lineReader.lineno() > maxTerms)
+ break;
+
+ String[] tokens = line.split("\\s+");
+ String word = tokens[1];
+ int count = Integer.parseInt(tokens[2]);
+
+ if (count >= threshold)
+ vocab.add(word);
+ }
+
+ } catch (IOException e) {
+ throw new RuntimeException(String.format(
+ "* FATAL: couldn't load TargetBigram vocabulary '%s'", filename), e);
+ }
+ }
+
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int spanStart, int spanEnd,
+ SourcePath sourcePath, Sentence sentence, Accumulator acc) {
+
+ int[] enWords = rule.getTarget();
+
+ int left = -1;
+ int right = -1;
+
+ List<String> currentNgram = new LinkedList<>();
+ for (int curID : enWords) {
+ if (FormatUtils.isNonterminal(curID)) {
+ int index = -(curID + 1);
+ NgramDPState state = (NgramDPState) tailNodes.get(index).getDPState(stateIndex);
+ int[] leftContext = state.getLeftLMStateWords();
+ int[] rightContext = state.getRightLMStateWords();
+
+ // Left context.
+ for (int token : leftContext) {
+ currentNgram.add(getWord(token));
+ if (left == -1)
+ left = token;
+ right = token;
+ if (currentNgram.size() == 2) {
+ String ngram = join(currentNgram);
+ acc.add(hashFeature(String.format("%s_%s", name, ngram)), 1);
+ // System.err.println(String.format("ADDING %s_%s", name, ngram));
+ currentNgram.remove(0);
+ }
+ }
+ // Replace right context.
+ int tSize = currentNgram.size();
+ for (int i = 0; i < rightContext.length; i++)
+ currentNgram.set(tSize - rightContext.length + i, getWord(rightContext[i]));
+
+ } else { // terminal words
+ currentNgram.add(getWord(curID));
+ if (left == -1)
+ left = curID;
+ right = curID;
+ if (currentNgram.size() == 2) {
+ String ngram = join(currentNgram);
+ acc.add(hashFeature(String.format("%s_%s", name, ngram)), 1);
+ // System.err.println(String.format("ADDING %s_%s", name, ngram));
+ currentNgram.remove(0);
+ }
+ }
+ }
+
+ // System.err.println(String.format("RULE %s -> state %s", rule.getRuleString(), state));
+ return new NgramDPState(new int[] { left }, new int[] { right });
+ }
+
+ /**
+ * Returns the word after comparing against the private vocabulary (if set).
+ *
+ * @param curID
+ * @return the word
+ */
+ private String getWord(int curID) {
+ String word = Vocabulary.word(curID);
+
+ if (vocab != null && ! vocab.contains(word)) {
+ return "UNK";
+ }
+
+ return word;
+ }
+
+ /**
+ * We don't compute a future cost.
+ */
+ @Override
+ public float estimateFutureCost(Rule rule, DPState state, Sentence sentence) {
+ return 0.0f;
+ }
+
+ /**
+ * There is nothing to be done here, since <s> and </s> are included in rules that are part
+ * of the grammar. We simply return the DP state of the tail node.
+ */
+ @Override
+ public DPState computeFinal(HGNode tailNode, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ return tailNode.getDPState(stateIndex);
+ }
+
+ /**
- * TargetBigram features are only computed across hyperedges, so there is nothing to be done here.
++ * TargetBigram features are only computed across hyperedges, so there is nothing to be done here.
+ */
+ @Override
+ public float estimateCost(Rule rule, Sentence sentence) {
+ return 0.0f;
+ }
+
+ /**
+ * Join a list with the _ character. I am sure this is in a library somewhere.
+ *
+ * @param list a list of strings
+ * @return the joined String
+ */
+ private String join(List<String> list) {
+ StringBuilder sb = new StringBuilder();
+ for (String item : list) {
+ sb.append(item).append("_");
+ }
+
+ return sb.substring(0, sb.length() - 1);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
index 036c4bc,0000000..f822fe4
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
@@@ -1,777 -1,0 +1,786 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.fragmentlm;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.io.StringReader;
- import java.util.*;
++import java.util.ArrayList;
++import java.util.Collection;
++import java.util.Collections;
++import java.util.HashMap;
++import java.util.HashSet;
++import java.util.Iterator;
++import java.util.List;
++import java.util.Set;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.fragmentlm.Trees.PennTreeReader;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.hypergraph.HGNode;
+import org.apache.joshua.decoder.hypergraph.HyperEdge;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationState;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Represent phrase-structure trees, with each node consisting of a label and a list of children.
+ * Borrowed from the Berkeley Parser, and extended to allow the representation of tree fragments in
+ * addition to complete trees (the BP requires terminals to be immediately governed by a
+ * preterminal). To distinguish terminals from nonterminals in fragments, the former must be
+ * enclosed in double-quotes when read in.
- *
++ *
+ * @author Dan Klein
+ * @author Matt Post post@cs.jhu.edu
+ */
+public class Tree implements Serializable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(Tree.class);
+ private static final long serialVersionUID = 1L;
+
+ /* Vocabulary id of the node label (surrounding quotes are stripped from terminals). */
+ protected int label;
+
+ /* Marks a frontier node as a terminal (as opposed to a nonterminal). */
+ boolean isTerminal = false;
+
+ /*
+ * Marks the root and frontier nodes of a fragment. Useful for denoting fragment derivations in
+ * larger trees.
+ */
+ boolean isBoundary = false;
+
+ /* A list of the node's children. */
+ List<Tree> children;
+
+ /* The maximum distance from the root to any of the frontier nodes; -1 until getDepth() runs. */
+ int depth = -1;
+
+ /*
+ * Cached by isLexicalized(): 1 for a terminal, otherwise the number of lexicalized children.
+ * -1 means not yet computed.
+ */
+ private int numLexicalItems = -1;
+
+ /*
+ * This maps the flat right-hand sides of Joshua rules to the tree fragments they were derived
+ * from. It is used to look up the fragment that language model fragments should be matched
+ * against. For example, if the target (English) side of your rule is
- *
++ *
+ * [NP,1] said [SBAR,2]
- *
++ *
+ * we will retrieve the unflattened fragment
- *
++ *
+ * (S NP (VP (VBD said) SBAR))
- *
++ *
+ * which presumably was the frontier fragment used to derive the translation rule. With this in
+ * hand, we can iterate through our store of language model fragments to match them against this,
+ * following tail nodes if necessary.
+ */
+ public static final HashMap<String, String> rulesToFragmentStrings = new HashMap<>();
+
+ public Tree(String label, List<Tree> children) {
+ setLabel(label);
+ this.children = children;
+ }
+
+ public Tree(String label) {
+ setLabel(label);
+ this.children = Collections.emptyList();
+ }
+
+ public Tree(int label2, ArrayList<Tree> newChildren) {
+ this.label = label2;
+ this.children = newChildren;
+ }
+
+ public void setChildren(List<Tree> c) {
+ this.children = c;
+ }
+
+ public List<Tree> getChildren() {
+ return children;
+ }
+
+ public int getLabel() {
+ return label;
+ }
+
+ /**
+ * Computes the depth-one rule rooted at this node. If the node has no children, null is returned.
- *
++ *
+ * NOTE(review): the returned string opens with "(" but is never closed with ")". Confirm whether
+ * callers depend on this before changing it.
+ *
+ * @return string representation of the rule
+ */
+ public String getRule() {
+ if (isLeaf()) {
+ return null;
+ }
+ StringBuilder ruleString = new StringBuilder("(" + Vocabulary.word(getLabel()));
+ for (Tree child : getChildren()) {
+ ruleString.append(" ").append(Vocabulary.word(child.getLabel()));
+ }
+ return ruleString.toString();
+ }
+
+ /*
+ * Boundary nodes are used externally to mark merge points between different fragments. This is
+ * separate from the internal ( (substitution point) denotation.
+ */
+ public boolean isBoundary() {
+ return isBoundary;
+ }
+
+ public void setBoundary(boolean b) {
+ this.isBoundary = b;
+ }
+
+ public boolean isTerminal() {
+ return isTerminal;
+ }
+
+ public boolean isLeaf() {
+ return getChildren().isEmpty();
+ }
+
+ public boolean isPreTerminal() {
+ return getChildren().size() == 1 && getChildren().get(0).isLeaf();
+ }
+
+ public List<Tree> getNonterminalYield() {
+ List<Tree> yield = new ArrayList<>();
+ appendNonterminalYield(this, yield);
+ return yield;
+ }
+
+ public List<Tree> getYield() {
+ List<Tree> yield = new ArrayList<>();
+ appendYield(this, yield);
+ return yield;
+ }
+
+ public List<Tree> getTerminals() {
+ List<Tree> yield = new ArrayList<>();
+ appendTerminals(this, yield);
+ return yield;
+ }
+
+ private static void appendTerminals(Tree tree, List<Tree> yield) {
+ if (tree.isLeaf()) {
+ yield.add(tree);
+ return;
+ }
+ for (Tree child : tree.getChildren()) {
+ appendTerminals(child, yield);
+ }
+ }
+
+ /**
+ * Clone the structure of the tree. Labels and the terminal/boundary flags are copied
+ * recursively. The cached depth and lexical counts are not copied.
- *
++ *
+ * @return a cloned tree
+ */
+ public Tree shallowClone() {
+ ArrayList<Tree> newChildren = new ArrayList<>(children.size());
+ for (Tree child : children) {
+ newChildren.add(child.shallowClone());
+ }
+
+ Tree newTree = new Tree(label, newChildren);
+ newTree.setIsTerminal(isTerminal());
+ newTree.setBoundary(isBoundary());
+ return newTree;
+ }
+
+ private void setIsTerminal(boolean terminal) {
+ isTerminal = terminal;
+ }
+
+ private static void appendNonterminalYield(Tree tree, List<Tree> yield) {
+ if (tree.isLeaf() && !tree.isTerminal()) {
+ yield.add(tree);
+ return;
+ }
+ for (Tree child : tree.getChildren()) {
+ appendNonterminalYield(child, yield);
+ }
+ }
+
+ private static void appendYield(Tree tree, List<Tree> yield) {
+ if (tree.isLeaf()) {
+ yield.add(tree);
+ return;
+ }
+ for (Tree child : tree.getChildren()) {
+ appendYield(child, yield);
+ }
+ }
+
+ public List<Tree> getPreTerminalYield() {
+ List<Tree> yield = new ArrayList<>();
+ appendPreTerminalYield(this, yield);
+ return yield;
+ }
+
+ private static void appendPreTerminalYield(Tree tree, List<Tree> yield) {
+ if (tree.isPreTerminal()) {
+ yield.add(tree);
+ return;
+ }
+ for (Tree child : tree.getChildren()) {
+ appendPreTerminalYield(child, yield);
+ }
+ }
+
+ /**
+ * A tree is lexicalized if it has terminal nodes among the leaves of its frontier. For normal
+ * trees this is always true since they bottom out in terminals, but for fragments, this may or
+ * may not be true.
- *
++ *
+ * @return true if the tree is lexicalized
+ */
+ public boolean isLexicalized() {
+ if (this.numLexicalItems < 0) {
+ if (isTerminal())
+ this.numLexicalItems = 1;
+ else {
+ // Count the lexicalized children; each child caches its own result.
+ this.numLexicalItems = (int) children.stream().filter(Tree::isLexicalized).count();
+ }
+ }
+
+ return (this.numLexicalItems > 0);
+ }
+
+ /**
+ * The depth of a tree is the maximum distance from the root to any of the frontier nodes.
+ * The result is cached in {@code depth} after the first call.
- *
++ *
+ * @return the tree depth
+ */
+ public int getDepth() {
+ if (this.depth >= 0)
+ return this.depth;
+
+ if (isLeaf()) {
+ this.depth = 0;
+ } else {
+ int maxDepth = 0;
+ for (Tree child : children) {
+ int childDepth = child.getDepth();
+ if (childDepth > maxDepth)
+ maxDepth = childDepth;
+ }
+ this.depth = maxDepth + 1;
+ }
+ return this.depth;
+ }
+
+ public List<Tree> getAtDepth(int depth) {
+ List<Tree> yield = new ArrayList<>();
+ appendAtDepth(depth, this, yield);
+ return yield;
+ }
+
+ private static void appendAtDepth(int depth, Tree tree, List<Tree> yield) {
+ if (depth < 0)
+ return;
+ if (depth == 0) {
+ yield.add(tree);
+ return;
+ }
+ for (Tree child : tree.getChildren()) {
+ appendAtDepth(depth - 1, child, yield);
+ }
+ }
+
+ /**
+ * Sets the label from its string form. A label of at least three characters that is wrapped in
+ * double quotes marks this node as a terminal, and the quotes are stripped. The label is stored
+ * as a Vocabulary id.
+ *
+ * @param label the node label, quoted for terminals
+ */
+ public void setLabel(String label) {
+ if (label.length() >= 3 && label.startsWith("\"") && label.endsWith("\"")) {
+ this.isTerminal = true;
+ label = label.substring(1, label.length() - 1);
+ }
+
+ this.label = Vocabulary.id(label);
+ }
+
++ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ toStringBuilder(sb);
+ return sb.toString();
+ }
+
+ /**
+ * Removes the quotes around terminals. Note that the resulting tree could not be read back
+ * in by this class, since unquoted leaves are interpreted as nonterminals.
- *
++ *
+ * @return unquoted string
+ */
+ public String unquotedString() {
+ return toString().replaceAll("\"", "");
+ }
-
++
+ public String escapedString() {
+ return toString().replaceAll(" ", "_");
+ }
+
+ /**
+ * Appends the bracketed representation of this tree to {@code sb}. Terminals are wrapped in
+ * double quotes.
+ *
+ * @param sb the builder to append to
+ */
+ public void toStringBuilder(StringBuilder sb) {
+ if (!isLeaf())
+ sb.append('(');
+
+ if (isTerminal())
+ sb.append('"').append(Vocabulary.word(getLabel())).append('"');
+ else
+ sb.append(Vocabulary.word(getLabel()));
+
+ if (!isLeaf()) {
+ for (Tree child : getChildren()) {
+ sb.append(' ');
+ child.toStringBuilder(sb);
+ }
+ sb.append(')');
+ }
+ }
+
+ /**
+ * Get the set of all subtrees inside the tree by returning a tree rooted at each node. These are
+ * <i>not</i> copies, but all share structure. The tree is regarded as a subtree of itself.
- *
++ *
+ * @return the <code>Set</code> of all subtrees in the tree.
+ */
+ public Set<Tree> subTrees() {
+ Set<Tree> result = new HashSet<>();
+ subTrees(result);
+ return result;
+ }
+
+ /**
+ * Get the list of all subtrees inside the tree by returning a tree rooted at each node. These are
+ * <i>not</i> copies, but all share structure. The tree is regarded as a subtree of itself.
- *
++ *
+ * @return the <code>List</code> of all subtrees in the tree.
+ */
+ public List<Tree> subTreeList() {
+ List<Tree> result = new ArrayList<>();
+ subTrees(result);
+ return result;
+ }
+
+ /**
+ * Add the set of all subtrees inside a tree (including the tree itself) to the given
+ * <code>Collection</code>.
- *
++ *
+ * @param n A collection of nodes to which the subtrees will be added
+ * @return The collection parameter with the subtrees added
+ */
+ public Collection<Tree> subTrees(Collection<Tree> n) {
+ n.add(this);
+ List<Tree> kids = getChildren();
+ for (Tree kid : kids) {
+ kid.subTrees(n);
+ }
+ return n;
+ }
+
+ /**
+ * Returns an iterator over the nodes of the tree. It does a preorder (children after node)
+ * traversal of the tree. (A possible extension to the class at some point would be to allow
+ * different traversal orderings via variant iterators.)
- *
++ *
+ * @return An iterator over the nodes of the tree
+ */
+ public Iterator<Tree> iterator() {
+ return new TreeIterator();
+ }
+
+ private class TreeIterator implements Iterator<Tree> {
+
+ private final List<Tree> treeStack;
+
+ private TreeIterator() {
+ treeStack = new ArrayList<>();
+ treeStack.add(Tree.this);
+ }
+
++ @Override
+ public boolean hasNext() {
+ return (!treeStack.isEmpty());
+ }
+
++ @Override
+ public Tree next() {
+ int lastIndex = treeStack.size() - 1;
+ Tree tr = treeStack.remove(lastIndex);
+ List<Tree> kids = tr.getChildren();
+ // so that we can efficiently use one List, we reverse them
+ for (int i = kids.size() - 1; i >= 0; i--) {
+ treeStack.add(kids.get(i));
+ }
+ return tr;
+ }
+
+ /**
+ * Not supported
+ */
++ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ }
+
+ public boolean hasUnaryChain() {
+ return hasUnaryChainHelper(this, false);
+ }
+
+ private boolean hasUnaryChainHelper(Tree tree, boolean unaryAbove) {
+ boolean result = false;
+ if (tree.getChildren().size() == 1) {
+ if (unaryAbove)
+ return true;
+ else if (tree.getChildren().get(0).isPreTerminal())
+ return false;
+ else
+ return hasUnaryChainHelper(tree.getChildren().get(0), true);
+ } else {
+ for (Tree child : tree.getChildren()) {
+ if (!child.isPreTerminal())
+ result = result || hasUnaryChainHelper(child, false);
+ }
+ }
+ return result;
+ }
+
+ /**
+ * Inserts the SOS (and EOS) symbols into a parse tree, attaching them as a left (right) sibling
+ * to the leftmost (rightmost) pre-terminal in the tree. This facilitates using trees as language
+ * models. The arguments have to be passed in to preserve Java generics, even though this is only
+ * ever used with String versions.
- *
++ *
+ * @param sos presumably "<s>"
+ * @param eos presumably "</s>"
+ */
+ public void insertSentenceMarkers(String sos, String eos) {
+ insertSentenceMarker(sos, 0);
+ insertSentenceMarker(eos, -1);
+ }
+
+ public void insertSentenceMarkers() {
+ insertSentenceMarker("<s>", 0);
+ insertSentenceMarker("</s>", -1);
+ }
+
+ /**
- *
++ *
+ * Descends along the first (pos == 0) or last (pos == -1) child until reaching a node whose child
+ * at that position is a pre-terminal, then adds {@code symbol} as a new leaf at that position.
+ * Leaves and pre-terminals are left unchanged.
+ *
+ * @param symbol the marker to insert
+ * @param pos the position at which to insert
+ */
+ private void insertSentenceMarker(String symbol, int pos) {
+
+ if (isLeaf() || isPreTerminal())
+ return;
+
+ List<Tree> children = getChildren();
+ int index = (pos == -1) ? children.size() - 1 : pos;
+ if (children.get(index).isPreTerminal()) {
+ if (pos == -1)
+ children.add(new Tree(symbol));
+ else
+ children.add(pos, new Tree(symbol));
+ } else {
+ children.get(index).insertSentenceMarker(symbol, pos);
+ }
+ }
+
+ /**
+ * This is a convenience function for producing a fragment from its string representation.
- *
++ *
+ * @param ptbStr input string from which to produce a fragment
+ * @return the fragment
+ */
+ public static Tree fromString(String ptbStr) {
+ PennTreeReader reader = new PennTreeReader(new StringReader(ptbStr));
+ return reader.next();
+ }
+
+ public static Tree getFragmentFromYield(String yield) {
+ String fragmentString = rulesToFragmentStrings.get(yield);
+ if (fragmentString != null)
+ return fromString(fragmentString);
+
+ return null;
+ }
+
+ /**
+ * Loads the rule-to-fragment mapping into {@link #rulesToFragmentStrings}. Each line must have
+ * the form {@code <fragment> ||| <flat rule yield>}, and the fragment must start with "(".
+ * Malformed lines are logged and skipped.
+ *
+ * @param fragmentMappingFile path of the mapping file
+ * @throws RuntimeException wrapping the IOException if the file cannot be read
+ */
+ public static void readMapping(String fragmentMappingFile) {
+ /* Read in the rule / fragments mapping */
- try {
- LineReader reader = new LineReader(fragmentMappingFile);
++ try (LineReader reader = new LineReader(fragmentMappingFile)) {
+ for (String line : reader) {
+ String[] fields = line.split("\\s+\\|{3}\\s+");
+ if (fields.length != 2 || !fields[0].startsWith("(")) {
+ LOG.warn("malformed line {}: {}", reader.lineno(), line);
+ continue;
+ }
+
+ rulesToFragmentStrings.put(fields[1].trim(), fields[0].trim()); // buildFragment(fields[0]));
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(String.format("Couldn't read fragment mapping file '%s'",
+ fragmentMappingFile), e);
+ }
+ LOG.info("FragmentLMFF: Read {} mappings from '{}'", rulesToFragmentStrings.size(),
+ fragmentMappingFile);
+ }
+
+ /**
+ * Builds a tree from the kth-best derivation state. This is done by initializing the tree with
+ * the internal fragment corresponding to the rule; this will be the top of the tree. We then
+ * recursively visit the derivation state objects, following the route through the hypergraph
+ * defined by them.
- *
++ *
+ * This function is like {@link #buildTree(Rule, List, int)},
+ * but that one simply follows the best incoming hyperedge for each node.
- *
++ *
+ * @param rule for which corresponding internal fragment can be used to initialize the tree
+ * @param derivationStates the derivation states of the rule's tail nodes, ordered by the source
+ * side. May be null, which is treated as having no derivation states.
+ * @param maxDepth decremented and passed to recursive calls. NOTE(review): this overload never
+ * compares it against zero, so it does not bound the recursion.
- * @return the Tree
++ * @return the Tree, or null if no fragment is mapped for the rule's target words
+ */
+ public static Tree buildTree(Rule rule, DerivationState[] derivationStates, int maxDepth) {
+ Tree tree = getFragmentFromYield(rule.getTargetWords());
+
+ if (tree == null) {
+ return null;
+ }
+
+ tree = tree.shallowClone();
+
+ // The recursive call below passes null when the next edge has no tail nodes. Treat that as an
+ // empty array instead of dereferencing it.
+ final int numStates = (derivationStates == null) ? 0 : derivationStates.length;
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("buildTree({})", tree);
+ for (int i = 0; i < numStates; i++) {
+ LOG.debug(" -> {}: {}", i, derivationStates[i]);
+ }
+ }
+
+ List<Tree> frontier = tree.getNonterminalYield();
+
+ /* The English side of a rule is a sequence of integers. Nonnegative integers are word
+ * indices in the Vocabulary, while negative indices are used for nonterminals. These negative
+ * indices are a *permutation* of the source side nonterminals, which contain the actual
+ * nonterminal Vocabulary indices for the nonterminal names. Here, we convert this permutation
- * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
++ * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
+ * the incoming DerivationState items, which are ordered by the source side.
+ */
+ ArrayList<Integer> tailIndices = new ArrayList<>();
+ int[] englishInts = rule.getTarget();
+ for (int i = 0; i < englishInts.length; i++)
+ if (englishInts[i] < 0)
+ tailIndices.add(-(englishInts[i] + 1));
+
+ /*
+ * We now have the tree's yield. The substitution points on the yield should match the
+ * nonterminals of the heads of the derivation states. Since we don't know which of the tree's
+ * frontier items are terminals and which are nonterminals, we walk through the tail nodes,
+ * and then match the label of each against the frontier node labels until we have a match.
+ */
+ // System.err.println(String.format("WORDS: %s\nTREE: %s", rule.getEnglishWords(), tree));
+ for (int i = 0; i < numStates; i++) {
+
+ Tree frontierTree = frontier.get(tailIndices.get(i));
+ frontierTree.setBoundary(true);
+
+ HyperEdge nextEdge = derivationStates[i].edge;
+ if (nextEdge != null) {
+ DerivationState[] nextStates = null;
+ if (nextEdge.getTailNodes() != null && nextEdge.getTailNodes().size() > 0) {
+ nextStates = new DerivationState[nextEdge.getTailNodes().size()];
+ for (int j = 0; j < nextStates.length; j++)
+ nextStates[j] = derivationStates[i].getChildDerivationState(nextEdge, j);
+ }
+ Tree childTree = buildTree(nextEdge.getRule(), nextStates, maxDepth - 1);
+
+ /* This can be null if there is no entry for the rule in the map */
+ if (childTree != null)
+ frontierTree.children = childTree.children;
+ } else {
+ // NOTE(review): frontierTree is a descendant of tree, so sharing the root's child list
+ // here makes the structure cyclic. Confirm the intended behavior.
+ frontierTree.children = tree.children;
+ }
+ }
-
++
+ return tree;
+ }
-
++
+ /**
+ * <p>Builds a tree from the kth-best derivation state. This is done by initializing the tree with
+ * the internal fragment corresponding to the rule; this will be the top of the tree. We then
+ * recursively visit the derivation state objects, following the route through the hypergraph
+ * defined by them.</p>
- *
++ *
+ * @param derivationState the derivation state whose incoming edge's rule forms the top of the tree
+ * @param maxDepth of route through the hypergraph; tail nodes are not expanded once it reaches 0
+ * @return the Tree, or null if no fragment is mapped for the rule's target words
+ */
+ public static Tree buildTree(DerivationState derivationState, int maxDepth) {
+ Rule rule = derivationState.edge.getRule();
-
+ Tree tree = getFragmentFromYield(rule.getTargetWords());
+
+ if (tree == null) {
+ return null;
+ }
+
+ tree = tree.shallowClone();
-
++
+ LOG.debug("buildTree({})", tree);
+
+ if (rule.getArity() > 0 && maxDepth > 0) {
+ List<Tree> frontier = tree.getNonterminalYield();
+
+ /* The English side of a rule is a sequence of integers. Nonnegative integers are word
+ * indices in the Vocabulary, while negative indices are used for nonterminals. These negative
+ * indices are a *permutation* of the source side nonterminals, which contain the actual
+ * nonterminal Vocabulary indices for the nonterminal names. Here, we convert this permutation
- * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
++ * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
+ * the incoming DerivationState items, which are ordered by the source side.
+ */
+ ArrayList<Integer> tailIndices = new ArrayList<>();
+ int[] targetInts = rule.getTarget();
+ for (int i = 0; i < targetInts.length; i++)
+ if (targetInts[i] < 0)
+ tailIndices.add(-(targetInts[i] + 1));
+
+ /*
+ * We now have the tree's yield. The substitution points on the yield should match the
+ * nonterminals of the heads of the derivation states. Since we don't know which of the tree's
+ * frontier items are terminals and which are nonterminals, we walk through the tail nodes,
+ * and then match the label of each against the frontier node labels until we have a match.
+ */
+ // System.err.println(String.format("WORDS: %s\nTREE: %s", rule.getEnglishWords(), tree));
+ for (int i = 0; i < rule.getArity(); i++) {
+
+ Tree frontierTree = frontier.get(tailIndices.get(i));
+ frontierTree.setBoundary(true);
+
+ DerivationState childState = derivationState.getChildDerivationState(derivationState.edge, i);
+ Tree childTree = buildTree(childState, maxDepth - 1);
+
+ /* This can be null if there is no entry for the rule in the map */
+ if (childTree != null)
+ frontierTree.children = childTree.children;
+ }
+ }
-
++
+ return tree;
+ }
+
+ /**
+ * Takes a rule and its tail pointers and recursively constructs a tree (up to maxDepth).
- *
++ *
+ * This could be implemented by using the other buildTree() function and using the 1-best
+ * DerivationState.
- *
++ *
+ * @param rule {@link org.apache.joshua.decoder.ff.tm.Rule} to be used whilst building the tree
+ * @param tailNodes {@link java.util.List} of {@link org.apache.joshua.decoder.hypergraph.HGNode}'s
+ * @param maxDepth to go in the tree
+ * @return shallow clone of the Tree object
+ */
+ public static Tree buildTree(Rule rule, List<HGNode> tailNodes, int maxDepth) {
+ Tree tree = getFragmentFromYield(rule.getTargetWords());
+
+ if (tree == null) {
+ // NOTE(review): this builds a single leaf whose label is the whole bracketed string. The
+ // string is not parsed with fromString(); confirm this is intended.
+ tree = new Tree(String.format("(%s %s)", Vocabulary.word(rule.getLHS()), rule.getTargetWords()));
+ } else {
+ tree = tree.shallowClone();
+ }
+
+ // tree is never null at this point (both branches above assign it).
+ if (tailNodes != null && tailNodes.size() > 0 && maxDepth > 0) {
+ List<Tree> frontier = tree.getNonterminalYield();
+
+ ArrayList<Integer> tailIndices = new ArrayList<>();
+ int[] targetInts = rule.getTarget();
+ for (int i = 0; i < targetInts.length; i++)
+ if (targetInts[i] < 0)
+ tailIndices.add(-(targetInts[i] + 1));
+
+ /*
+ * We now have the tree's yield. The substitution points on the yield should match the
+ * nonterminals of the tail nodes. Since we don't know which of the tree's frontier items are
+ * terminals and which are nonterminals, we walk through the tail nodes, and then match the
+ * label of each against the frontier node labels until we have a match.
+ */
+ // System.err.println(String.format("WORDS: %s\nTREE: %s", rule.getEnglishWords(), tree));
+ for (int i = 0; i < tailNodes.size(); i++) {
+
+ // String lhs = tailNodes.get(i).getLHS().replaceAll("[\\[\\]]", "");
+ // System.err.println(String.format(" %d: %s", i, lhs));
+ try {
+ Tree frontierTree = frontier.get(tailIndices.get(i));
+ frontierTree.setBoundary(true);
+
+ HyperEdge edge = tailNodes.get(i).bestHyperedge;
+ if (edge != null) {
+ Tree childTree = buildTree(edge.getRule(), edge.getTailNodes(), maxDepth - 1);
+ /* This can be null if there is no entry for the rule in the map */
+ if (childTree != null)
+ frontierTree.children = childTree.children;
+ } else {
+ // NOTE(review): as in the DerivationState[] overload, this shares the root's child list
+ // with one of its own descendants, which makes the structure cyclic.
+ frontierTree.children = tree.children;
+ }
+ } catch (IndexOutOfBoundsException e) {
+ LOG.error("ERROR at index {}", i);
+ LOG.error("RULE: {} TREE: {}", rule.getTargetWords(), tree);
+ LOG.error(" FRONTIER:");
+ for (Tree kid : frontier) {
+ LOG.error(" {}", kid);
+ }
+ throw new RuntimeException(String.format("ERROR at index %d", i), e);
+ }
+ }
+ }
+
+ return tree;
+ }
+
- public static void main(String[] args) {
- LineReader reader = new LineReader(System.in);
-
- for (String line : reader) {
- try {
- Tree tree = Tree.fromString(line);
- tree.insertSentenceMarkers();
- System.out.println(tree);
- } catch (Exception e) {
- System.out.println("");
++ /**
++ * Reads one bracketed tree per line from stdin and prints it with sentence markers inserted.
++ * An empty line is printed for any line that fails to parse.
++ */
++ public static void main(String[] args) throws IOException {
++ try (LineReader reader = new LineReader(System.in)) {
++ for (String line : reader) {
++ try {
++ Tree tree = Tree.fromString(line);
++ tree.insertSentenceMarkers();
++ System.out.println(tree);
++ } catch (Exception e) {
++ System.out.println("");
++ }
+ }
+ }
+
+ /*
+ * Tree fragment = Tree
+ * .fromString("(TOP (S (NP (DT the) (NN boy)) (VP (VBD ate) (NP (DT the) (NN food)))))");
+ * fragment.insertSentenceMarkers("<s>", "</s>");
- *
++ *
+ * System.out.println(fragment);
- *
++ *
+ * ArrayList<Tree> trees = new ArrayList<Tree>(); trees.add(Tree.fromString("(NN \"mat\")"));
+ * trees.add(Tree.fromString("(S (NP DT NN) VP)"));
+ * trees.add(Tree.fromString("(S (NP (DT \"the\") NN) VP)"));
+ * trees.add(Tree.fromString("(S (NP (DT the) NN) VP)"));
- *
++ *
+ * for (Tree tree : trees) { System.out.println(String.format("TREE %s DEPTH %d LEX? %s", tree,
+ * tree.getDepth(), tree.isLexicalized())); }
+ */
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index 93d54ed,0000000..044c85f
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@@ -1,257 -1,0 +1,259 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.state_maintenance.KenLMState;
+import org.apache.joshua.util.FormatUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * JNI wrapper for KenLM. This version of KenLM supports two use cases, implemented by the separate
+ * feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
+ * lm/left.hh, returning a state pointer representing the KenLM state, while LanguageModelFF handles
+ * state by itself and just passes in the ngrams for scoring.
- *
++ *
+ * @author Kenneth Heafield
+ * @author Matt Post post@cs.jhu.edu
+ */
+
+public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(KenLM.class);
+
+ // handle to the native model, as returned by construct(file_name)
+ private final long pointer;
+
+ // this is read from the config file, used to set maximum order
+ private final int ngramOrder;
+ // inferred from model file (may be larger than ngramOrder)
+ private final int N;
+
+ private static native long construct(String file_name);
+
+ private static native void destroy(long ptr);
+
+ private static native int order(long ptr);
+
+ private static native boolean registerWord(long ptr, String word, int id);
+
+ private static native float prob(long ptr, int words[]);
+
+ private static native float probForString(long ptr, String[] words);
+
+ private static native boolean isKnownWord(long ptr, String word);
-
++
+ private static native boolean isLmOov(long ptr, int word);
+
+ private static native StateProbPair probRule(long ptr, long pool, long words[]);
-
++
+ private static native float estimateRule(long ptr, long words[]);
+
+ private static native float probString(long ptr, int words[], int start);
+
+ private static native long createPool();
+
+ private static native void destroyPool(long pointer);
+
+ public KenLM(int order, String file_name) {
+ pointer = initializeSystemLibrary(file_name);
+ ngramOrder = order;
+ N = order(pointer);
+ }
+
+ /**
+ * Constructor if order is not known.
+ * Order will be inferred from the model.
+ * @param file_name string path to an input file
+ */
+ public KenLM(String file_name) {
+ pointer = initializeSystemLibrary(file_name);
+ N = order(pointer);
+ ngramOrder = N;
+ }
+
+ /**
+ * Loads the "ken" native library and constructs the native model for {@code file_name}.
+ *
+ * @throws KenLMLoadException if the native library cannot be linked
+ */
+ private long initializeSystemLibrary(String file_name) {
+ try {
+ System.loadLibrary("ken");
+ return construct(file_name);
+ } catch (UnsatisfiedLinkError e) {
+ LOG.error("Can't find libken.so (libken.dylib on OS X) on the Java library path.");
+ throw new KenLMLoadException(e);
+ }
+ }
+
- public class KenLMLoadException extends RuntimeException {
++ public static class KenLMLoadException extends RuntimeException {
++
++ private static final long serialVersionUID = 1L;
+
+ public KenLMLoadException(UnsatisfiedLinkError e) {
+ super(e);
+ }
+ }
+
+ public long createLMPool() {
+ return createPool();
+ }
+
+ public void destroyLMPool(long pointer) {
+ destroyPool(pointer);
+ }
+
+ public void destroy() {
+ destroy(pointer);
+ }
+
++ @Override
+ public int getOrder() {
+ return ngramOrder;
+ }
+
++ @Override
+ public boolean registerWord(String word, int id) {
+ return registerWord(pointer, word, id);
+ }
+
+ public float prob(int[] words) {
+ return prob(pointer, words);
+ }
+
+ /**
+ * Query for n-gram probability using strings.
+ * @param words a string array of words
+ * @return float value denoting probability
+ */
+ public float prob(String[] words) {
+ return probForString(pointer, words);
+ }
+
+ // Apparently Zhifei starts some array indices at 1. Change to 0-indexing.
+ public float probString(int words[], int start) {
+ return probString(pointer, words, start - 1);
+ }
+
+ /**
+ * This function is the bridge to the interface in kenlm/lm/left.hh, which has KenLM score the
+ * whole rule. It takes an array of words and states retrieved from tail nodes (nonterminals in the
+ * rule). Nonterminals have a negative value so KenLM can distinguish them. The sentence number is
+ * needed so KenLM knows which memory pool to use. When finished, it returns the updated KenLM
+ * state and the LM probability incurred along this rule.
- *
++ *
+ * @param words array of words
+ * @param poolPointer native memory pool handle, presumably one obtained from {@link #createLMPool()}
- * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
++ * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
+ * KenLM state and the LM probability incurred along this rule
+ * @throws RuntimeException wrapping a NoSuchMethodError raised by the native call
+ */
+ public StateProbPair probRule(long[] words, long poolPointer) {
+ try {
+ return probRule(pointer, poolPointer, words);
+ } catch (NoSuchMethodError e) {
+ // Fail the same way estimateRule() does. Library code should not terminate the whole JVM.
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Public facing function that estimates the cost of a rule, which value is used for sorting
+ * rules during cube pruning.
- *
++ *
+ * @param words array of words
+ * @return the estimated cost of the rule (the (partial) n-gram probabilities of all words in the rule)
+ */
+ public float estimateRule(long[] words) {
+ float estimate = 0.0f;
+ try {
+ estimate = estimateRule(pointer, words);
+ } catch (NoSuchMethodError e) {
+ throw new RuntimeException(e);
+ }
-
++
+ return estimate;
+ }
+
+ /**
+ * The start symbol for a KenLM is the Vocabulary.START_SYM.
+ * @return "<s>"
+ */
+ public String getStartSymbol() {
+ return Vocabulary.START_SYM;
+ }
-
++
+ /**
+ * Returns whether the given Vocabulary ID is unknown to the
+ * KenLM vocabulary. This can be used for a LanguageModel_OOV features
+ * and does not need to convert to an intermediate string.
+ */
+ @Override
+ public boolean isOov(int wordId) {
+ if (FormatUtils.isNonterminal(wordId)) {
+ throw new IllegalArgumentException("Should not query for nonterminals!");
+ }
+ return isLmOov(pointer, wordId);
+ }
+
+ public boolean isKnownWord(String word) {
+ return isKnownWord(pointer, word);
+ }
+
+
+ /**
+ * Inner class used to hold the results returned from KenLM with left-state minimization. Note
+ * that inner classes have to be static to be accessible from the JNI!
+ */
+ public static class StateProbPair {
+ public KenLMState state = null;
+ public float prob = 0.0f;
+
+ public StateProbPair(long state, float prob) {
+ this.state = new KenLMState(state);
+ this.prob = prob;
+ }
+ }
+
+ /**
+ * Imposes an arbitrary but consistent order on KenLM instances, based on their native model
+ * pointers. Two instances compare equal only if they are the same object or wrap the same
+ * native pointer. Returning -1 for every other instance would violate the {@link Comparable}
+ * requirement that sgn(a.compareTo(b)) == -sgn(b.compareTo(a)).
+ */
+ @Override
+ public int compareTo(KenLM other) {
+ if (this == other)
+ return 0;
+ return Long.compare(pointer, other.pointer);
+ }
+
+ /**
+ * These functions are used if KenLM is invoked under LanguageModelFF instead of KenLMFF.
+ */
+ @Override
+ public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+ return probString(sentence, startIndex);
+ }
+
+ @Override
+ public float ngramLogProbability(int[] ngram, int order) {
+ if (order != N && order != ngram.length)
+ throw new RuntimeException("Lower order not supported.");
+ return prob(ngram);
+ }
+
+ @Override
+ public float ngramLogProbability(int[] ngram) {
+ return prob(ngram);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
index 9bfccb0,0000000..0615077
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
@@@ -1,334 -1,0 +1,284 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder.ff.lm.buildin_lm;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Scanner;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.lm.AbstractLM;
+import org.apache.joshua.decoder.ff.lm.ArpaFile;
+import org.apache.joshua.decoder.ff.lm.ArpaNgram;
+import org.apache.joshua.util.Bits;
+import org.apache.joshua.util.Regex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Relatively memory-compact language model
+ * stored as a reversed-word-order trie.
+ * <p>
+ * The trie itself represents language model context.
+ * <p>
- * Conceptually, each node in the trie stores a map
++ * Conceptually, each node in the trie stores a map
+ * from conditioning word to log probability.
+ * <p>
- * Additionally, each node in the trie stores
++ * Additionally, each node in the trie stores
+ * the backoff weight for that context.
+ * <p>
+ * Concretely, {@code children} and {@code logProbs} are keyed by
+ * {@link Bits#encodeAsLong} of (node id, word id), while {@code backoffs} is keyed by
+ * node id alone. Node 0 is the root.
+ * <p>
+ * Only construction is implemented here: the probability helpers and {@link #isOov(int)}
+ * throw.
- *
++ *
+ * @author Lane Schwartz
+ * @see <a href="http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html">SRILM ngram-discount documentation</a>
+ */
+public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
+
+  private static final Logger LOG = LoggerFactory.getLogger(TrieLM.class);
+
+  /**
+   * Node ID for the root node.
+   */
+  private static final int ROOT_NODE_ID = 0;
+
+
-  /**
-   * Maps from (node id, word id for child) --> node id of child.
++  /**
++   * Maps from (node id, word id for child) --> node id of child.
+   */
+  private final Map<Long,Integer> children;
+
+  /**
-   * Maps from (node id, word id for lookup word) -->
-   * log prob of lookup word given context
-   *
++   * Maps from (node id, word id for lookup word) -->
++   * log prob of lookup word given context
++   *
+   * (the context is defined by where you are in the tree).
+   */
+  private final Map<Long,Float> logProbs;
+
+  /**
-   * Maps from (node id) -->
-   * backoff weight for that context
-   *
++   * Maps from (node id) -->
++   * backoff weight for that context
++   *
+   * (the context is defined by where you are in the tree).
+   */
+  private final Map<Integer,Float> backoffs;
+
+  /**
+   * Constructs a language model by reading the ARPA file at the given path.
+   *
+   * @param vocab vocabulary handed to the {@link ArpaFile} reader
+   * @param file path to the ARPA file
+   * @throws FileNotFoundException if the input file cannot be located
+   */
+  public TrieLM(Vocabulary vocab, String file) throws FileNotFoundException {
+    this(new ArpaFile(file,vocab));
+  }
+
+  /**
+   * Constructs a language model object from the specified ARPA file.
-   *
++   *
+   * @param arpaFile input ARPA file
+   * @throws FileNotFoundException if the input file cannot be located
+   */
+  public TrieLM(ArpaFile arpaFile) throws FileNotFoundException {
+    super(Vocabulary.size(), arpaFile.getOrder());
+
+    int ngramCounts = arpaFile.size();
+    LOG.debug("ARPA file contains {} n-grams", ngramCounts);
+
+    this.children = new HashMap<>(ngramCounts);
+    this.logProbs = new HashMap<>(ngramCounts);
+    this.backoffs = new HashMap<>(ngramCounts);
+
+    // Node ids are allocated with ++nodeCounter, so the first new node gets 1 and 0 stays
+    // reserved for the root.
+    int nodeCounter = 0;
+
+    int lineNumber = 0;
+    for (ArpaNgram ngram : arpaFile) {
+      lineNumber += 1;
+      if (lineNumber % 100000 == 0){
+        LOG.info("Line: {}", lineNumber);
+      }
+
+      LOG.debug("{}-gram: ({} | {})", ngram.order(), ngram.getWord(),
+          Arrays.toString(ngram.getContext()));
+      int word = ngram.getWord();
+
+      int[] context = ngram.getContext();
+
+      {
+        // Find where the log prob should be stored
+        int contextNodeID = ROOT_NODE_ID;
+        {
+          // Walk the context from its last word to its first (reversed word order), creating
+          // any trie nodes that do not exist yet.
+          for (int i=context.length-1; i>=0; i--) {
+            long key = Bits.encodeAsLong(contextNodeID, context[i]);
+            int childID;
+            if (children.containsKey(key)) {
+              childID = children.get(key);
+            } else {
+              childID = ++nodeCounter;
+              LOG.debug("children.put({}:{}, {})", contextNodeID, context[i], childID);
+              children.put(key, childID);
+            }
+            contextNodeID = childID;
+          }
+        }
+
+        // Store the log prob for this n-gram at this node in the trie
+        {
+          long key = Bits.encodeAsLong(contextNodeID, word);
+          float logProb = ngram.getValue();
+          LOG.debug("logProbs.put({}:{}, {}", contextNodeID, word, logProb);
+          this.logProbs.put(key, logProb);
+        }
+      }
+
+      {
+        // Find where the backoff should be stored
+        // The backoff node is reached from the root via the n-gram's own word first, then
+        // via its context words from last to first.
+        int backoffNodeID = ROOT_NODE_ID;
-        {
++        {
+          long backoffNodeKey = Bits.encodeAsLong(backoffNodeID, word);
+          int wordChildID;
+          if (children.containsKey(backoffNodeKey)) {
+            wordChildID = children.get(backoffNodeKey);
+          } else {
+            wordChildID = ++nodeCounter;
+            LOG.debug("children.put({}: {}, {})", backoffNodeID, word, wordChildID);
+            children.put(backoffNodeKey, wordChildID);
+          }
+          backoffNodeID = wordChildID;
+
+          for (int i=context.length-1; i>=0; i--) {
+            long key = Bits.encodeAsLong(backoffNodeID, context[i]);
+            int childID;
+            if (children.containsKey(key)) {
+              childID = children.get(key);
+            } else {
+              childID = ++nodeCounter;
+              LOG.debug("children.put({}:{}, {})", backoffNodeID, context[i], childID);
+              children.put(key, childID);
+            }
+            backoffNodeID = childID;
+          }
+        }
+
+        // Store the backoff for this n-gram at this node in the trie
+        {
+          float backoff = ngram.getBackoff();
+          LOG.debug("backoffs.put({}:{}, {})", backoffNodeID, word, backoff);
+          this.backoffs.put(backoffNodeID, backoff);
+        }
+      }
+
+    }
+  }
+
+
+  /**
+   * Not supported by TrieLM; always throws {@link UnsupportedOperationException}.
+   */
+  @Override
-  protected double logProbabilityOfBackoffState_helper(
-      int[] ngram, int order, int qtyAdditionalBackoffWeight
-      ) {
++  protected double logProbabilityOfBackoffState_helper(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
+    throw new UnsupportedOperationException("probabilityOfBackoffState_helper undefined for TrieLM");
+  }
+
+  /**
+   * Not implemented for TrieLM; always throws {@link UnsupportedOperationException}.
+   * (The earlier trie walk never produced a log probability and returned {@code (Float) null}.)
+   */
+  @Override
+  protected float ngramLogProbability_helper(int[] ngram, int order) {
-
-    // float logProb = (float) -JoshuaConfiguration.lm_ceiling_cost;//Float.NEGATIVE_INFINITY; // log(0.0f)
-    float backoff = 0.0f; // log(1.0f)
-
-    int i = ngram.length - 1;
-    int word = ngram[i];
-    i -= 1;
-
-    int nodeID = ROOT_NODE_ID;
-
-    while (true) {
-
-      {
-        long key = Bits.encodeAsLong(nodeID, word);
-        if (logProbs.containsKey(key)) {
-          // logProb = logProbs.get(key);
-          backoff = 0.0f; // log(0.0f)
-        }
-      }
-
-      if (i < 0) {
-        break;
-      }
-
-      {
-        long key = Bits.encodeAsLong(nodeID, ngram[i]);
-
-        if (children.containsKey(key)) {
-          nodeID = children.get(key);
-
-          backoff += backoffs.get(nodeID);
-
-          i -= 1;
-
-        } else {
-          break;
-        }
-      }
-
-    }
-
-    // double result = logProb + backoff;
-    // if (result < -JoshuaConfiguration.lm_ceiling_cost) {
-    // result = -JoshuaConfiguration.lm_ceiling_cost;
-    // }
-    //
-    // return result;
-    return (Float) null;
++    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Returns the internal, mutable map from encoded (node id, word id) keys to child node ids.
+   * The map is not copied, so callers can modify the trie through it.
+   *
+   * @return the live children map
+   */
+  public Map<Long,Integer> getChildren() {
+    return this.children;
+  }
+
+  /**
+   * Command-line demo: builds a TrieLM from {@code args[0]} (ARPA file), then reads
+   * {@code args[1]} line by line. Each line is wrapped in {@code <s> ... </s>}, and the log
+   * probability of every window of up to {@code args[2]} words and of the whole sentence is
+   * logged.
+   * <p>
+   * NOTE(review): {@link #ngramLogProbability_helper} throws in this class; if AbstractLM routes
+   * the calls below through it, this demo fails. Confirm before relying on it.
+   *
+   * @param args ARPA file, text file to score, n-gram order
+   * @throws IOException if a file cannot be read
+   */
+  public static void main(String[] args) throws IOException {
+
+    LOG.info("Constructing ARPA file");
+    ArpaFile arpaFile = new ArpaFile(args[0]);
+
+    LOG.info("Getting symbol table");
+    // NOTE(review): vocab is not used anywhere below.
+    Vocabulary vocab = arpaFile.getVocab();
+
+    LOG.info("Constructing TrieLM");
+    TrieLM lm = new TrieLM(arpaFile);
+
+    int n = Integer.valueOf(args[2]);
+    LOG.info("N-gram order will be {}", n);
+
-    Scanner scanner = new Scanner(new File(args[1]));
++    try (Scanner scanner = new Scanner(new File(args[1]));) {
++      LinkedList<String> wordList = new LinkedList<>();
++      LinkedList<String> window = new LinkedList<>();
+
-    LinkedList<String> wordList = new LinkedList<>();
-    LinkedList<String> window = new LinkedList<>();
++      LOG.info("Starting to scan {}", args[1]);
++      while (scanner.hasNext()) {
+
-    LOG.info("Starting to scan {}", args[1]);
-    while (scanner.hasNext()) {
++        LOG.info("Getting next line...");
++        String line = scanner.nextLine();
++        LOG.info("Line: {}", line);
+
-      LOG.info("Getting next line...");
-      String line = scanner.nextLine();
-      LOG.info("Line: {}", line);
++        String[] words = Regex.spaces.split(line);
++        wordList.clear();
+
-      String[] words = Regex.spaces.split(line);
-      wordList.clear();
++        wordList.add("<s>");
++        Collections.addAll(wordList, words);
++        wordList.add("</s>");
+
-      wordList.add("<s>");
-      Collections.addAll(wordList, words);
-      wordList.add("</s>");
-
-      ArrayList<Integer> sentence = new ArrayList<>();
-      // int[] ids = new int[wordList.size()];
-      for (String aWordList : wordList) {
-        sentence.add(Vocabulary.id(aWordList));
-        // ids[i] = ;
-      }
++        ArrayList<Integer> sentence = new ArrayList<>();
++        // int[] ids = new int[wordList.size()];
++        for (String aWordList : wordList) {
++          sentence.add(Vocabulary.id(aWordList));
++          // ids[i] = ;
++        }
+
++        while (!wordList.isEmpty()) {
++          window.clear();
+
++          {
++            int i = 0;
++            for (String word : wordList) {
++              if (i >= n)
++                break;
++              window.add(word);
++              i++;
++            }
++            wordList.remove();
++          }
+
-      while (! wordList.isEmpty()) {
-        window.clear();
++          {
++            int i = 0;
++            int[] wordIDs = new int[window.size()];
++            for (String word : window) {
++              wordIDs[i] = Vocabulary.id(word);
++              i++;
++            }
+
-        {
-          int i=0;
-          for (String word : wordList) {
-            if (i>=n) break;
-            window.add(word);
-            i++;
++            LOG.info("logProb {} = {}", window, lm.ngramLogProbability(wordIDs, n));
+           }
-          wordList.remove();
+         }
+
-        {
-          int i=0;
-          int[] wordIDs = new int[window.size()];
-          for (String word : window) {
-            wordIDs[i] = Vocabulary.id(word);
-            i++;
-          }
++        double logProb = lm.sentenceLogProbability(sentence, n, 2);// .ngramLogProbability(ids,
++        // n);
++        double prob = Math.exp(logProb);
+
-          LOG.info("logProb {} = {}", window, lm.ngramLogProbability(wordIDs, n));
-        }
++        LOG.info("Total logProb = {}", logProb);
++        LOG.info("Total prob = {}", prob);
+       }
-
-      double logProb = lm.sentenceLogProbability(sentence, n, 2);//.ngramLogProbability(ids, n);
-      double prob = Math.exp(logProb);
-
-      LOG.info("Total logProb = {}", logProb);
-      LOG.info("Total prob = {}", prob);
+     }
-
+   }
+
+  /**
+   * Not implemented for TrieLM; always throws.
+   */
+  @Override
+  public boolean isOov(int id) {
+    throw new RuntimeException("Not implemented!");
+  }
+
+}
[08/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
index 5861052,0000000..838279b
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
+++ b/joshua-core/src/main/java/org/apache/joshua/tools/GrammarPacker.java
@@@ -1,936 -1,0 +1,940 @@@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import static org.apache.joshua.decoder.ff.tm.OwnerMap.UNKNOWN_OWNER_ID;
+import static org.apache.joshua.decoder.ff.tm.packed.PackedGrammar.VOCABULARY_FILENAME;
+
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Queue;
+import java.util.TreeMap;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.RuleFactory;
+import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
+import org.apache.joshua.decoder.ff.tm.format.MosesFormatReader;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.encoding.EncoderConfiguration;
+import org.apache.joshua.util.encoding.FeatureTypeAnalyzer;
+import org.apache.joshua.util.encoding.IntEncoder;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public class GrammarPacker {
+
+  private static final Logger LOG = LoggerFactory.getLogger(GrammarPacker.class);
+
+  /**
+   * The packed grammar version number. Increment this any time you add new features, and update
+   * the documentation.
-   *
++   *
+   * Version history:
-   *
++   *
+   * - 3 (May 2016). This was the first version that was marked. It removed the special phrase-
+   * table packing that packed phrases without the [X,1] on the source and target sides, which
+   * then required special handling in the decoder to use for phrase-based decoding.
-   *
++   *
+   * - 4 (August 2016). Phrase-based decoding rewritten to represent phrases without a builtin
+   * nonterminal. Instead, cost-less glue rules are used in phrase-based decoding. This eliminates
+   * the need for special handling of phrase grammars (except for having to add a LHS), and lets
+   * phrase grammars be used in both hierarchical and phrase-based decoding without conversion.
+   *
+   */
+  public static final int VERSION = 4;
-
++
+  // Size limit for slice in bytes (80% of Integer.MAX_VALUE).
+  // NOTE(review): not final; confirm nothing reassigns it before making it a constant.
+  private static int DATA_SIZE_LIMIT = (int) (Integer.MAX_VALUE * 0.8);
+  // Estimated average number of feature entries for one rule.
+  private static int DATA_SIZE_ESTIMATE = 20;
+
+  // Joins the first two source words into the key that decides where a slice may end.
+  private static final String SOURCE_WORDS_SEPARATOR = " ||| ";
+
+  // Output directory name.
+  private String output;
+
+  // Input grammar to be packed.
+  private String grammar;
+
+  /**
+   * @return the path of the input grammar being packed
+   */
+  public String getGrammar() {
+    return this.grammar;
+  }
+
+  /**
+   * @return the directory the packed grammar is written to
+   */
+  public String getOutputDirectory() {
+    return this.output;
+  }
+
+  // Approximate maximum size of a slice in number of rules
+  // (set by the constructor; may be overridden by "slice_size" in the config file).
+  private int approximateMaximumSliceSize;
+
+  // Always true as set by the constructor; passed to types.inferTypes() during packing.
+  private boolean labeled;
+
+  // True when alignments are packed, either from the grammar rules or from a separate file.
+  private boolean packAlignments;
+  // True when alignments are read from the grammar rules themselves.
+  private boolean grammarAlignments;
+  // Path to a separate alignments file; may be null.
+  private String alignments;
+
+  // Collects feature values during exploration and infers their encoding types.
+  private FeatureTypeAnalyzer types;
+  // Encoding read back from disk before the packing pass.
+  private EncoderConfiguration encoderConfig;
+
+  // Optional path to dump the feature type statistics to; may be null.
+  private String dump;
+
+  // Longest source side seen during exploration; written to the packed config.
+  private int max_source_len;
+
+  /**
+   * Creates a packer for the given grammar and prepares the output directory.
+   *
+   * @param grammar_filename input grammar (Hiero or Moses format)
+   * @param config_filename optional packer config; when null, feature types are auto-detected
+   * @param output_filename output directory; created, including any missing parents, if needed
+   * @param alignments_filename optional separate alignments file (may be null)
+   * @param featuredump_filename optional path to dump feature type statistics to (may be null)
+   * @param grammar_alignments whether alignments are read from the grammar rules themselves
+   * @param approximateMaximumSliceSize approximate number of rules per slice; may be overridden
+   *     by {@code slice_size} in the config file
+   * @throws IOException if the config file cannot be read
+   * @throws RuntimeException if the alignments file does not exist or the output directory
+   *     cannot be created
+   */
+  public GrammarPacker(String grammar_filename, String config_filename, String output_filename,
+      String alignments_filename, String featuredump_filename, boolean grammar_alignments,
+      int approximateMaximumSliceSize)
+      throws IOException {
+    this.labeled = true;
+    this.grammar = grammar_filename;
+    this.output = output_filename;
+    this.dump = featuredump_filename;
+    this.grammarAlignments = grammar_alignments;
+    this.approximateMaximumSliceSize = approximateMaximumSliceSize;
+    this.max_source_len = 0;
+
+    // TODO: Always open encoder config? This is debatable.
+    this.types = new FeatureTypeAnalyzer(true);
+
+    this.alignments = alignments_filename;
+    packAlignments = grammarAlignments || (alignments != null);
+    if (!packAlignments) {
+      LOG.info("No alignments file or grammar specified, skipping.");
+    } else if (alignments != null && !new File(alignments_filename).exists()) {
+      throw new RuntimeException("Alignments file does not exist: " + alignments);
+    }
+
+    if (config_filename != null) {
+      readConfig(config_filename);
+      types.readConfig(config_filename);
+    } else {
+      LOG.info("No config specified. Attempting auto-detection of feature types.");
+    }
+    LOG.info("Approximate maximum slice size (in # of rules) set to {}", approximateMaximumSliceSize);
+
+    File working_dir = new File(output);
+    // mkdirs() also creates missing parent directories, which mkdir() does not. isDirectory()
+    // rejects a regular file that already occupies the output path; exists() alone would
+    // accept it and packing would fail later when writing files into it.
+    working_dir.mkdirs();
+    if (!working_dir.isDirectory()) {
+      throw new RuntimeException("Failed creating output directory.");
+    }
+  }
+
+  /**
+   * Reads packer options from a config file. Text after '#' is ignored, as are lines that are
+   * empty after that. Every remaining line must have at least two whitespace-separated fields.
+   * Only the {@code slice_size} key is interpreted here; it overrides
+   * {@code approximateMaximumSliceSize}.
+   *
+   * @param config_filename path to the config file
+   * @throws IOException if the file cannot be read
+   * @throws RuntimeException if a non-empty line has fewer than two fields
+   */
+  private void readConfig(String config_filename) throws IOException {
-    LineReader reader = new LineReader(config_filename);
-    while (reader.hasNext()) {
-      // Clean up line, chop comments off and skip if the result is empty.
-      String line = reader.next().trim();
-      if (line.indexOf('#') != -1)
-        line = line.substring(0, line.indexOf('#'));
-      if (line.isEmpty())
-        continue;
-      String[] fields = line.split("[\\s]+");
-
-      if (fields.length < 2) {
-        throw new RuntimeException("Incomplete line in config.");
-      }
-      if ("slice_size".equals(fields[0])) {
-        // Number of records to concurrently load into memory for sorting.
-        approximateMaximumSliceSize = Integer.parseInt(fields[1]);
++    try(LineReader reader = new LineReader(config_filename);) {
++      while (reader.hasNext()) {
++        // Clean up line, chop comments off and skip if the result is empty.
++        String line = reader.next().trim();
++        if (line.indexOf('#') != -1)
++          line = line.substring(0, line.indexOf('#'));
++        if (line.isEmpty())
++          continue;
++        String[] fields = line.split("[\\s]+");
++
++        if (fields.length < 2) {
++          throw new RuntimeException("Incomplete line in config.");
++        }
++        if ("slice_size".equals(fields[0])) {
++          // Number of records to concurrently load into memory for sorting.
++          approximateMaximumSliceSize = Integer.parseInt(fields[1]);
++        }
+     }
+   }
-    reader.close();
+  }
+
+  /**
+   * Executes the packing.
+   * <p>
+   * Runs an exploration pass that collects the longest source side and feature value
+   * statistics, writes the encoding, vocabulary and config into the output directory, and then
+   * runs the packing pass that writes the binary slices. Every file opened here is closed, even
+   * when a step fails.
-   *
++   *
+   * @throws IOException if there is an error reading the grammar or writing the output
+   */
+  public void pack() throws IOException {
+    LOG.info("Beginning exploration pass.");
+
+    // Explore pass. Learn vocabulary and feature value histograms.
+    LOG.info("Exploring: {}", grammar);
+
+    HieroFormatReader grammarReader = getGrammarReader();
+    explore(grammarReader);
+
+    LOG.info("Exploration pass complete. Freezing vocabulary and finalizing encoders.");
+    if (dump != null) {
+      // try-with-resources closes the dump file even if writing fails.
+      try (PrintWriter dump_writer = new PrintWriter(dump)) {
+        dump_writer.println(types.toString());
+      }
+    }
+
+    types.inferTypes(this.labeled);
+    LOG.info("Type inference complete.");
+
+    LOG.info("Finalizing encoding.");
+
+    LOG.info("Writing encoding.");
+    types.write(output + File.separator + "encoding");
+
+    writeVocabulary();
+
+    String configFile = output + File.separator + "config";
+    LOG.info("Writing config to '{}'", configFile);
+    // Write config options; the writer is closed even if a write fails.
+    try (FileWriter config = new FileWriter(configFile)) {
+      config.write(String.format("version = %d\n", VERSION));
+      config.write(String.format("max-source-len = %d\n", max_source_len));
+    }
+
+    // Read previously written encoder configuration to match up to changed
+    // vocabulary id's.
+    LOG.info("Reading encoding.");
+    encoderConfig = new EncoderConfiguration();
+    encoderConfig.load(output + File.separator + "encoding");
+
+    LOG.info("Beginning packing pass.");
+    // Actual binarization pass. Slice and pack source, target and data.
+    grammarReader = getGrammarReader();
+    // A separate alignments file is only read when alignments are not embedded in the grammar.
+    // A null resource is legal in try-with-resources (close() is simply skipped), so the reader
+    // is closed exactly when it was opened, whether binarize() succeeds or throws.
+    try (LineReader alignment_reader =
+        (packAlignments && !grammarAlignments) ? new LineReader(alignments) : null) {
+      binarize(grammarReader, alignment_reader);
+    }
+    LOG.info("Packing complete.");
+
+    LOG.info("Packed grammar in: {}", output);
+    LOG.info("Done.");
+  }
+
+  /**
+   * Returns a reader that turns whatever file format is found into unowned Hiero grammar rules.
+   * This means, features are NOT prepended with an owner string at packing time.
+   * <p>
+   * The format is detected from the first line of {@code grammar}: a line starting with '['
+   * selects {@link HieroFormatReader}; anything else selects {@link MosesFormatReader}. The
+   * probe reader is closed before returning, and the returned reader is constructed from the
+   * grammar path.
-   *
++   *
+   * NOTE(review): an empty grammar file is not handled here. Confirm what reader.next() does
+   * at EOF (returns null or throws) before relying on this for empty input.
+   *
+   * @return GrammarReader of correct Format
+   * @throws IOException if the grammar file cannot be opened
+   */
+  private HieroFormatReader getGrammarReader() throws IOException {
-    LineReader reader = new LineReader(grammar);
-    String line = reader.next();
-    if (line.startsWith("[")) {
-      return new HieroFormatReader(grammar, UNKNOWN_OWNER_ID);
-    } else {
-      return new MosesFormatReader(grammar, UNKNOWN_OWNER_ID);
++    try (LineReader reader = new LineReader(grammar);) {
++      String line = reader.next();
++      if (line.startsWith("["))
++        return new HieroFormatReader(grammar, UNKNOWN_OWNER_ID);
++      else
++        return new MosesFormatReader(grammar, UNKNOWN_OWNER_ID);
+     }
+   }
+
+  /**
-   * This first pass over the grammar
++   * Exploration (first) pass over the grammar.
+   * Records the longest source side in {@code max_source_len} and feeds every feature value of
+   * every rule to the feature type analyzer, which {@code pack()} later uses to infer encodings.
+   *
+   * @param reader reader over the grammar's rules
+   */
+  private void explore(HieroFormatReader reader) {
+
+    // We always assume a labeled grammar. Unlabeled features are assumed to be dense and to always
+    // appear in the same order. They are assigned numeric names in order of appearance.
+    this.types.setLabeled(true);
+
+    for (Rule rule : reader) {
+
+      max_source_len = Math.max(max_source_len, rule.getSource().length);
+
+      /* Vocabulary symbols are added by the grammar reader while it parses each rule, not here.
+       * NOTE: In case of nonterminals, we add both stripped versions ("[X]")
+       * and "[X,1]" to the vocabulary.
-       *
++       *
+       * TODO: MJP May 2016: Is it necessary to add [X,1]? This is currently being done in
-       * {@link HieroFormatReader}, which is called by {@link MosesFormatReader}.
++       * {@link HieroFormatReader}, which is called by {@link MosesFormatReader}.
+       */
+
+      // Feed each feature value to the type analyzer so encoder types can be inferred later.
+      for (final Entry<Integer, Float> entry : rule.getFeatureVector().entrySet()) {
+        types.observe(entry.getKey(), entry.getValue());
+      }
+    }
+  }
+
+ /**
+ * Returns a String encoding the first two source words.
+ * If there is only one source word, use empty string for the second.
+ */
+ private String getFirstTwoSourceWords(final String[] source_words) {
+ return source_words[0] + SOURCE_WORDS_SEPARATOR + ((source_words.length > 1) ? source_words[1] : "");
+ }
+
+  /**
+   * Packing (second) pass over the grammar. Adds every rule to the source and target tries and
+   * to the feature buffer (and, if enabled, the alignment buffer). A slice is flushed to disk
+   * once it is full and the rule's first two source words differ from those seen when the limit
+   * was reached. The final slice is flushed after the last rule.
+   *
+   * @param grammarReader reader over the grammar's rules
+   * @param alignment_reader reader for a separate alignments file; only read when alignments are
+   *     packed but not taken from the grammar (may be null otherwise)
+   * @throws IOException if writing a slice fails
+   */
+  private void binarize(HieroFormatReader grammarReader, LineReader alignment_reader) throws IOException {
+    int counter = 0;
+    int slice_counter = 0;
+    int num_slices = 0;
+
+    boolean ready_to_flush = false;
+    // to determine when flushing is possible
+    String prev_first_two_source_words = null;
+
+    PackingTrie<SourceValue> source_trie = new PackingTrie<SourceValue>();
+    PackingTrie<TargetValue> target_trie = new PackingTrie<TargetValue>();
+    FeatureBuffer feature_buffer = new FeatureBuffer();
+
+    AlignmentBuffer alignment_buffer = null;
+    if (packAlignments)
+      alignment_buffer = new AlignmentBuffer();
+
+    // Reused for every rule.
+    // NOTE(review): assumes FeatureBuffer.add() copies or encodes the map, since it is cleared
+    // for the next rule.
+    TreeMap<Integer, Float> features = new TreeMap<Integer, Float>();
+    for (Rule rule : grammarReader) {
+      counter++;
+      slice_counter++;
+
+      String lhs_word = Vocabulary.word(rule.getLHS());
+      String[] source_words = rule.getSourceWords().split("\\s+");
+      String[] target_words = rule.getTargetWords().split("\\s+");
+
+      // Reached slice limit size, indicate that we're closing up.
+      if (!ready_to_flush
+          && (slice_counter > approximateMaximumSliceSize
+              || feature_buffer.overflowing()
+              || (packAlignments && alignment_buffer.overflowing()))) {
+        ready_to_flush = true;
+        // store the first two source words when slice size limit was reached
+        prev_first_two_source_words = getFirstTwoSourceWords(source_words);
+      }
+      // ready to flush
+      if (ready_to_flush) {
+        final String first_two_source_words = getFirstTwoSourceWords(source_words);
+        // the grammar can only be partitioned at the level of first two source word changes.
+        // Thus, we can only flush if the current first two source words differ from the ones
+        // when the slice size limit was reached.
+        if (!first_two_source_words.equals(prev_first_two_source_words)) {
+          // NOTE(review): this is a routine event, so WARN seems too loud; INFO/DEBUG may fit.
+          LOG.warn("ready to flush and first two words have changed ({} vs. {})",
+              prev_first_two_source_words, first_two_source_words);
+          // NOTE(review): slice_counter already counts the current rule, which goes into the
+          // *next* slice, so the count logged for the first slice is one too high.
+          LOG.info("flushing {} rules to slice.", slice_counter);
+          flush(source_trie, target_trie, feature_buffer, alignment_buffer, num_slices);
+          source_trie.clear();
+          target_trie.clear();
+          feature_buffer.clear();
+          if (packAlignments)
+            alignment_buffer.clear();
+
+          num_slices++;
+          slice_counter = 0;
+          ready_to_flush = false;
+        }
+      }
+
+      int alignment_index = -1;
+      // If present, process alignments.
+      if (packAlignments) {
+        // NOTE(review): this local shadows the `alignments` field (the alignments file path).
+        byte[] alignments = null;
+        if (grammarAlignments) {
+          alignments = rule.getAlignment();
+        } else {
+          if (!alignment_reader.hasNext()) {
+            LOG.error("No more alignments starting in line {}", counter);
+            throw new RuntimeException("No more alignments starting in line " + counter);
+          }
+          alignments = RuleFactory.parseAlignmentString(alignment_reader.next().trim());
+        }
+        alignment_index = alignment_buffer.add(alignments);
+      }
+
+      // Process features.
+      // Implicitly sort via TreeMap, write to data buffer, remember position
+      // to pass on to the source trie node.
+      features.clear();
+      for (Entry<Integer, Float> entry : rule.getFeatureVector().entrySet()) {
+        int featureId = entry.getKey();
+        float featureValue = entry.getValue();
+        // Zero-valued features are not stored.
+        if (featureValue != 0f) {
+          features.put(encoderConfig.innerId(featureId), featureValue);
+        }
+      }
+
+      int features_index = feature_buffer.add(features);
+
+      // Sanity check on the data block index.
+      if (packAlignments && features_index != alignment_index) {
+        LOG.error("Block index mismatch between features ({}) and alignments ({}).",
+            features_index, alignment_index);
+        throw new RuntimeException("Data block index mismatch.");
+      }
+
+      // Process source side.
+      // Nonterminals are stored with their index stripped.
+      SourceValue sv = new SourceValue(Vocabulary.id(lhs_word), features_index);
+      int[] source = new int[source_words.length];
+      for (int i = 0; i < source_words.length; i++) {
+        if (FormatUtils.isNonterminal(source_words[i]))
+          source[i] = Vocabulary.id(FormatUtils.stripNonTerminalIndex(source_words[i]));
+        else
+          source[i] = Vocabulary.id(source_words[i]);
+      }
+      source_trie.add(source, sv);
+
+      // Process target side.
+      // Target words are stored in reverse order; nonterminals become their negated index.
+      TargetValue tv = new TargetValue(sv);
+      int[] target = new int[target_words.length];
+      for (int i = 0; i < target_words.length; i++) {
+        if (FormatUtils.isNonterminal(target_words[i])) {
+          target[target_words.length - (i + 1)] = -FormatUtils.getNonterminalIndex(target_words[i]);
+        } else {
+          target[target_words.length - (i + 1)] = Vocabulary.id(target_words[i]);
+        }
+      }
+      target_trie.add(target, tv);
+    }
+    // flush the last (possibly partial) slice
+    flush(source_trie, target_trie, feature_buffer, alignment_buffer, num_slices);
+  }
+
+  /**
+   * Serializes the source, target and feature data structures into interlinked binary files. Target
+   * is written first, into a skeletal (nodes don't carry any data) upward-pointing trie, updating
+   * the linking source trie nodes with the position once it is known. Source and feature data are
+   * written simultaneously. The source structure is written into a downward-pointing trie and
+   * stores the rule's lhs as well as links to the target and feature stream. The feature stream is
+   * prompted to write out a block.
-   *
++   *
+   * @param source_trie trie over the source sides of the rules in this slice
+   * @param target_trie trie over the (reversed) target sides of the rules in this slice
+   * @param feature_buffer feature data for the rules in this slice
+   * @param alignment_buffer alignment data for this slice; only used when alignments are packed
+   * @param id slice number, used to name the slice ({@code slice_%05d})
+   * @throws IOException if writing any of the slice files fails
+   */
+  private void flush(PackingTrie<SourceValue> source_trie,
+      PackingTrie<TargetValue> target_trie, FeatureBuffer feature_buffer,
+      AlignmentBuffer alignment_buffer, int id) throws IOException {
+    // Make a slice object for this piece of the grammar.
+    PackingFileTuple slice = new PackingFileTuple("slice_" + String.format("%05d", id));
+    // Pull out the streams for source, target and data output.
+    // NOTE(review): these streams are closed only on the success path below; an IOException
+    // part-way through leaves them open unless PackingFileTuple takes care of that.
+    DataOutputStream source_stream = slice.getSourceOutput();
+    DataOutputStream target_stream = slice.getTargetOutput();
+    DataOutputStream target_lookup_stream = slice.getTargetLookupOutput();
+    DataOutputStream feature_stream = slice.getFeatureOutput();
+    DataOutputStream alignment_stream = slice.getAlignmentOutput();
+
+    Queue<PackingTrie<TargetValue>> target_queue;
+    Queue<PackingTrie<SourceValue>> source_queue;
+
+    // Offset, in ints (the unit of PackingTrie.size()), covering what has been written into
+    // the source stream plus what is buffered in the source queue.
+    int source_position;
+    // Offset, in ints, written into the target stream so far.
+    int target_position;
+
+    // Add trie root into queue, set target position to 0 and set cumulated
+    // size to size of trie root.
+    target_queue = new LinkedList<PackingTrie<TargetValue>>();
+    target_queue.add(target_trie);
+    target_position = 0;
+
+    // Target lookup table for trie levels: records the target offset reached at the end of
+    // each breadth-first level.
+    int current_level_size = 1;
+    int next_level_size = 0;
+    ArrayList<Integer> target_lookup = new ArrayList<Integer>();
+
+    // Packing loop for upwards-pointing target trie.
+    while (!target_queue.isEmpty()) {
+      // Pop top of queue.
+      PackingTrie<TargetValue> node = target_queue.poll();
+      // Register that this is where we're writing the node to.
+      node.address = target_position;
+      // Tell source nodes that we're writing to this position in the file.
+      for (TargetValue tv : node.values)
+        tv.parent.target = node.address;
+      // Write link to parent.
+      if (node.parent != null)
+        target_stream.writeInt(node.parent.address);
+      else
+        target_stream.writeInt(-1);
+      target_stream.writeInt(node.symbol);
+      // Enqueue children.
+      for (int k : node.children.descendingKeySet()) {
+        PackingTrie<TargetValue> child = node.children.get(k);
+        target_queue.add(child);
+      }
+      target_position += node.size(false, true);
+      next_level_size += node.children.descendingKeySet().size();
+
+      current_level_size--;
+      if (current_level_size == 0) {
+        target_lookup.add(target_position);
+        current_level_size = next_level_size;
+        next_level_size = 0;
+      }
+    }
+    target_lookup_stream.writeInt(target_lookup.size());
+    for (int i : target_lookup)
+      target_lookup_stream.writeInt(i);
+    target_lookup_stream.close();
+
+    // Setting up for source and data writing.
+    source_queue = new LinkedList<PackingTrie<SourceValue>>();
+    source_queue.add(source_trie);
+    source_position = source_trie.size(true, false);
+    source_trie.address = target_position;
+
+    // Ready data buffers for writing.
+    feature_buffer.initialize();
+    if (packAlignments)
+      alignment_buffer.initialize();
+
+    // Packing loop for downwards-pointing source trie.
+    while (!source_queue.isEmpty()) {
+      // Pop top of queue.
+      PackingTrie<SourceValue> node = source_queue.poll();
+      // Write number of children.
+      source_stream.writeInt(node.children.size());
+      // Write links to children.
+      for (int k : node.children.descendingKeySet()) {
+        PackingTrie<SourceValue> child = node.children.get(k);
+        // Enqueue child.
+        source_queue.add(child);
+        // Child's address will be at the current end of the queue.
+        child.address = source_position;
+        // Advance cumulated size by child's size.
+        source_position += child.size(true, false);
+        // Write the link.
+        source_stream.writeInt(k);
+        source_stream.writeInt(child.address);
+      }
+      // Write number of data items.
+      source_stream.writeInt(node.values.size());
+      // Write lhs and links to target and data.
+      for (SourceValue sv : node.values) {
+        int feature_block_index = feature_buffer.write(sv.data);
+        if (packAlignments) {
+          int alignment_block_index = alignment_buffer.write(sv.data);
+          if (alignment_block_index != feature_block_index) {
+            LOG.error("Block index mismatch.");
+            throw new RuntimeException("Block index mismatch: alignment (" + alignment_block_index
+                + ") and features (" + feature_block_index + ") don't match.");
+          }
+        }
+        source_stream.writeInt(sv.lhs);
+        source_stream.writeInt(sv.target);
+        source_stream.writeInt(feature_block_index);
+      }
+    }
+    // Flush the data stream.
+    feature_buffer.flush(feature_stream);
+    if (packAlignments)
+      alignment_buffer.flush(alignment_stream);
+
+    target_stream.close();
+    source_stream.close();
+    feature_stream.close();
+    if (packAlignments)
+      alignment_stream.close();
+  }
+
+  /**
+   * Writes the global vocabulary into the output directory under {@code VOCABULARY_FILENAME}.
+   *
+   * @throws IOException propagated from {@code Vocabulary.write}
+   */
+  public void writeVocabulary() throws IOException {
+    final String target = output + File.separator + VOCABULARY_FILENAME;
+    LOG.info("Writing vocabulary to {}", target);
+    Vocabulary.write(target);
+  }
+
+  /**
+   * Integer-labeled, doubly-linked trie with some provisions for packing. Each node holds its
+   * symbol, its parent, its children ordered by symbol, the values stored at it, and the
+   * address assigned to it while a slice is being written.
+   * <p>
+   * This is a static nested class. It uses no state of the enclosing {@code GrammarPacker}, so
+   * it does not need, or retain, a hidden reference to one.
-   *
++   *
+   * @author Juri Ganitkevitch
-   *
++   *
+   * @param <D> The trie's value type.
+   */
+  static class PackingTrie<D extends PackingTrieValue> {
+    // Label of the edge leading to this node (0 for a root).
+    int symbol;
+    // Parent node, or null for a root.
+    PackingTrie<D> parent;
+
+    // Children keyed by symbol; a TreeMap keeps them in symbol order.
+    TreeMap<Integer, PackingTrie<D>> children;
+    // Values stored at this node.
+    List<D> values;
+
+    // Position assigned while packing; -1 until assigned.
+    int address;
+
+    PackingTrie() {
+      address = -1;
+
+      symbol = 0;
+      parent = null;
+
+      children = new TreeMap<Integer, PackingTrie<D>>();
+      values = new ArrayList<D>();
+    }
+
+    PackingTrie(PackingTrie<D> parent, int symbol) {
+      this();
+      this.parent = parent;
+      this.symbol = symbol;
+    }
+
+    /** Stores {@code value} at the node reached by {@code path}, creating nodes as needed. */
+    void add(int[] path, D value) {
+      add(path, 0, value);
+    }
+
+    private void add(int[] path, int index, D value) {
+      if (index == path.length)
+        this.values.add(value);
+      else {
+        PackingTrie<D> child = children.get(path[index]);
+        if (child == null) {
+          child = new PackingTrie<D>(this, path[index]);
+          children.put(path[index], child);
+        }
+        child.add(path, index + 1, value);
+      }
+    }
+
+    /**
+     * Calculate the size (in ints) of a packed trie node. Distinguishes downwards pointing (parent
+     * points to children) from upwards pointing (children point to parent) tries, as well as
+     * skeletal (no data, just the labeled links) and non-skeletal (nodes have a data block)
+     * packing.
-     *
++     *
+     * @param downwards Are we packing into a downwards-pointing trie?
+     * @param skeletal Are we packing into a skeletal trie?
-     *
++     *
+     * @return Number of ints (not bytes) the trie node would occupy.
+     */
+    int size(boolean downwards, boolean skeletal) {
+      int size = 0;
+      if (downwards) {
+        // Number of children and links to children.
+        size = 1 + 2 * children.size();
+      } else {
+        // Link to parent.
+        size += 2;
+      }
+      // Non-skeletal packing: number of data items.
+      if (!skeletal)
+        size += 1;
+      // Non-skeletal packing: write size taken up by data items.
+      if (!skeletal && !values.isEmpty())
+        size += values.size() * values.get(0).size();
+
+      return size;
+    }
+
+    /** Removes all children and values; symbol, parent and address are left unchanged. */
+    void clear() {
+      children.clear();
+      values.clear();
+    }
+  }
+
+ interface PackingTrieValue {
+ int size();
+ }
+
+ class SourceValue implements PackingTrieValue {
+ int lhs;
+ int data;
+ int target;
+
+ public SourceValue() {
+ }
+
+ SourceValue(int lhs, int data) {
+ this.lhs = lhs;
+ this.data = data;
+ }
+
+ void setTarget(int target) {
+ this.target = target;
+ }
+
++ @Override
+ public int size() {
+ return 3;
+ }
+ }
+
+ class TargetValue implements PackingTrieValue {
+ SourceValue parent;
+
+ TargetValue(SourceValue parent) {
+ this.parent = parent;
+ }
+
++ @Override
+ public int size() {
+ return 0;
+ }
+ }
+
+ abstract class PackingBuffer<T> {
+ private byte[] backing;
+ protected ByteBuffer buffer;
+
+ protected ArrayList<Integer> memoryLookup;
+ protected int totalSize;
+ protected ArrayList<Integer> onDiskOrder;
+
+ PackingBuffer() throws IOException {
+ allocate();
+ memoryLookup = new ArrayList<Integer>();
+ onDiskOrder = new ArrayList<Integer>();
+ totalSize = 0;
+ }
+
+ abstract int add(T item);
+
+ // Allocate a reasonably-sized buffer for the feature data.
+ private void allocate() {
+ backing = new byte[approximateMaximumSliceSize * DATA_SIZE_ESTIMATE];
+ buffer = ByteBuffer.wrap(backing);
+ }
+
+ // Reallocate the backing array and buffer, copies data over.
+ protected void reallocate() {
+ if (backing.length == Integer.MAX_VALUE)
+ return;
+ long attempted_length = backing.length * 2l;
+ int new_length;
+ // Detect overflow.
+ if (attempted_length >= Integer.MAX_VALUE)
+ new_length = Integer.MAX_VALUE;
+ else
+ new_length = (int) attempted_length;
+ byte[] new_backing = new byte[new_length];
+ System.arraycopy(backing, 0, new_backing, 0, backing.length);
+ int old_position = buffer.position();
+ ByteBuffer new_buffer = ByteBuffer.wrap(new_backing);
+ new_buffer.position(old_position);
+ buffer = new_buffer;
+ backing = new_backing;
+ }
+
+ /**
+ * Prepare the data buffer for disk writing.
+ */
+ void initialize() {
+ onDiskOrder.clear();
+ }
+
+ /**
+ * Enqueue a data block for later writing.
- *
++ *
+ * @param block_index The index of the data block to add to writing queue.
+ * @return The to-be-written block's output index.
+ */
+ int write(int block_index) {
+ onDiskOrder.add(block_index);
+ return onDiskOrder.size() - 1;
+ }
+
+ /**
+ * Performs the actual writing to disk in the order specified by calls to write() since the last
+ * call to initialize().
- *
++ *
+ * @param out
+ * @throws IOException
+ */
+ void flush(DataOutputStream out) throws IOException {
+ writeHeader(out);
+ int size;
+ int block_address;
+ for (int block_index : onDiskOrder) {
+ block_address = memoryLookup.get(block_index);
+ size = blockSize(block_index);
+ out.write(backing, block_address, size);
+ }
+ }
+
+ void clear() {
+ buffer.clear();
+ memoryLookup.clear();
+ onDiskOrder.clear();
+ }
+
+ boolean overflowing() {
+ return (buffer.position() >= DATA_SIZE_LIMIT);
+ }
+
+ private void writeHeader(DataOutputStream out) throws IOException {
+ if (out.size() == 0) {
+ out.writeInt(onDiskOrder.size());
+ out.writeInt(totalSize);
+ int disk_position = headerSize();
+ for (int block_index : onDiskOrder) {
+ out.writeInt(disk_position);
+ disk_position += blockSize(block_index);
+ }
+ } else {
+ throw new RuntimeException("Got a used stream for header writing.");
+ }
+ }
+
+ private int headerSize() {
+ // One integer for each data block, plus number of blocks and total size.
+ return 4 * (onDiskOrder.size() + 2);
+ }
+
+ private int blockSize(int block_index) {
+ int block_address = memoryLookup.get(block_index);
+ return (block_index < memoryLookup.size() - 1 ? memoryLookup.get(block_index + 1) : totalSize)
+ - block_address;
+ }
+ }
+
+ class FeatureBuffer extends PackingBuffer<TreeMap<Integer, Float>> {
+
+ private IntEncoder idEncoder;
+
+ FeatureBuffer() throws IOException {
+ super();
+ idEncoder = types.getIdEncoder();
+ LOG.info("Encoding feature ids in: {}", idEncoder.getKey());
+ }
+
+ /**
+ * Add a block of features to the buffer.
- *
++ *
+ * @param features TreeMap with the features for one rule.
+ * @return The index of the resulting data block.
+ */
++ @Override
+ int add(TreeMap<Integer, Float> features) {
+ int data_position = buffer.position();
+
+ // Over-estimate how much room this addition will need: for each
+ // feature (ID_SIZE for label, "upper bound" of 4 for the value), plus ID_SIZE for
+ // the number of features. If this won't fit, reallocate the buffer.
+ int size_estimate = (4 + EncoderConfiguration.ID_SIZE) * features.size()
+ + EncoderConfiguration.ID_SIZE;
+ if (buffer.capacity() - buffer.position() <= size_estimate)
+ reallocate();
+
+ // Write features to buffer.
+ idEncoder.write(buffer, features.size());
+ for (Integer k : features.descendingKeySet()) {
+ float v = features.get(k);
+ // Sparse features.
+ if (v != 0.0) {
+ idEncoder.write(buffer, k);
+ encoderConfig.encoder(k).write(buffer, v);
+ }
+ }
+ // Store position the block was written to.
+ memoryLookup.add(data_position);
+ // Update total size (in bytes).
+ totalSize = buffer.position();
+
+ // Return block index.
+ return memoryLookup.size() - 1;
+ }
+ }
+
+ class AlignmentBuffer extends PackingBuffer<byte[]> {
+
+ AlignmentBuffer() throws IOException {
+ super();
+ }
+
+ /**
+ * Add a rule alignments to the buffer.
- *
++ *
+ * @param alignments a byte array with the alignment points for one rule.
+ * @return The index of the resulting data block.
+ */
++ @Override
+ int add(byte[] alignments) {
+ int data_position = buffer.position();
+ int size_estimate = alignments.length + 1;
+ if (buffer.capacity() - buffer.position() <= size_estimate)
+ reallocate();
+
+ // Write alignment points to buffer.
+ buffer.put((byte) (alignments.length / 2));
+ buffer.put(alignments);
+
+ // Store position the block was written to.
+ memoryLookup.add(data_position);
+ // Update total size (in bytes).
+ totalSize = buffer.position();
+ // Return block index.
+ return memoryLookup.size() - 1;
+ }
+ }
+
+ class PackingFileTuple implements Comparable<PackingFileTuple> {
+ private File sourceFile;
+ private File targetLookupFile;
+ private File targetFile;
+
+ private File featureFile;
+ private File alignmentFile;
+
+ PackingFileTuple(String prefix) {
+ sourceFile = new File(output + File.separator + prefix + ".source");
+ targetFile = new File(output + File.separator + prefix + ".target");
+ targetLookupFile = new File(output + File.separator + prefix + ".target.lookup");
+ featureFile = new File(output + File.separator + prefix + ".features");
+
+ alignmentFile = null;
+ if (packAlignments)
+ alignmentFile = new File(output + File.separator + prefix + ".alignments");
+
+ LOG.info("Allocated slice: {}", sourceFile.getAbsolutePath());
+ }
+
+ DataOutputStream getSourceOutput() throws IOException {
+ return getOutput(sourceFile);
+ }
+
+ DataOutputStream getTargetOutput() throws IOException {
+ return getOutput(targetFile);
+ }
+
+ DataOutputStream getTargetLookupOutput() throws IOException {
+ return getOutput(targetLookupFile);
+ }
+
+ DataOutputStream getFeatureOutput() throws IOException {
+ return getOutput(featureFile);
+ }
+
+ DataOutputStream getAlignmentOutput() throws IOException {
+ if (alignmentFile != null)
+ return getOutput(alignmentFile);
+ return null;
+ }
+
+ private DataOutputStream getOutput(File file) throws IOException {
+ if (file.createNewFile()) {
+ return new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
+ } else {
+ throw new RuntimeException("File doesn't exist: " + file.getName());
+ }
+ }
+
+ long getSize() {
+ return sourceFile.length() + targetFile.length() + featureFile.length();
+ }
+
+ @Override
+ public int compareTo(PackingFileTuple o) {
+ if (getSize() > o.getSize()) {
+ return -1;
+ } else if (getSize() < o.getSize()) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
index 2fd2b3f,0000000..8f15b0e
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
+++ b/joshua-core/src/main/java/org/apache/joshua/tools/LabelPhrases.java
@@@ -1,111 -1,0 +1,111 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import java.io.IOException;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.corpus.syntax.ArraySyntaxTree;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Finds labeling for a set of phrases.
- *
++ *
+ * @author Juri Ganitkevitch
+ */
+public class LabelPhrases {
+
+ private static final Logger LOG = LoggerFactory.getLogger(LabelPhrases.class);
+
+ /**
+ * Main method.
- *
++ *
+ * @param args names of the two grammars to be compared
+ * @throws IOException if there is an error reading the input grammars
+ */
+ public static void main(String[] args) throws IOException {
+
+ if (args.length < 1 || args[0].equals("-h")) {
+ System.err.println("Usage: " + LabelPhrases.class.toString());
+ System.err.println(" -p phrase_file phrase-sentence file to process");
+ System.err.println();
+ System.exit(-1);
+ }
+
+ String phrase_file_name = null;
+
+ for (int i = 0; i < args.length; i++) {
+ if ("-p".equals(args[i])) phrase_file_name = args[++i];
+ }
+ if (phrase_file_name == null) {
+ LOG.error("a phrase file is required for operation");
+ System.exit(-1);
+ }
+
- LineReader phrase_reader = new LineReader(phrase_file_name);
-
- while (phrase_reader.ready()) {
- String line = phrase_reader.readLine();
-
- String[] fields = line.split("\\t");
- if (fields.length != 3 || fields[2].equals("()")) {
- System.err.println("[FAIL] Empty parse in line:\t" + line);
- continue;
- }
-
- String[] phrase_strings = fields[0].split("\\s");
- int[] phrase_ids = new int[phrase_strings.length];
- for (int i = 0; i < phrase_strings.length; i++)
- phrase_ids[i] = Vocabulary.id(phrase_strings[i]);
++ try (LineReader phrase_reader = new LineReader(phrase_file_name);) {
++ while (phrase_reader.ready()) {
++ String line = phrase_reader.readLine();
+
- ArraySyntaxTree syntax = new ArraySyntaxTree(fields[2]);
- int[] sentence_ids = syntax.getTerminals();
++ String[] fields = line.split("\\t");
++ if (fields.length != 3 || fields[2].equals("()")) {
++ System.err.println("[FAIL] Empty parse in line:\t" + line);
++ continue;
++ }
+
- int match_start = -1;
- int match_end = -1;
- for (int i = 0; i < sentence_ids.length; i++) {
- if (phrase_ids[0] == sentence_ids[i]) {
- match_start = i;
- int j = 0;
- while (j < phrase_ids.length && phrase_ids[j] == sentence_ids[i + j]) {
- j++;
- }
- if (j == phrase_ids.length) {
- match_end = i + j;
- break;
++ String[] phrase_strings = fields[0].split("\\s");
++ int[] phrase_ids = new int[phrase_strings.length];
++ for (int i = 0; i < phrase_strings.length; i++)
++ phrase_ids[i] = Vocabulary.id(phrase_strings[i]);
++
++ ArraySyntaxTree syntax = new ArraySyntaxTree(fields[2]);
++ int[] sentence_ids = syntax.getTerminals();
++
++ int match_start = -1;
++ int match_end = -1;
++ for (int i = 0; i < sentence_ids.length; i++) {
++ if (phrase_ids[0] == sentence_ids[i]) {
++ match_start = i;
++ int j = 0;
++ while (j < phrase_ids.length && phrase_ids[j] == sentence_ids[i + j]) {
++ j++;
++ }
++ if (j == phrase_ids.length) {
++ match_end = i + j;
++ break;
++ }
+ }
+ }
- }
+
- int label = syntax.getOneConstituent(match_start, match_end);
- if (label == 0) label = syntax.getOneSingleConcatenation(match_start, match_end);
- if (label == 0) label = syntax.getOneRightSideCCG(match_start, match_end);
- if (label == 0) label = syntax.getOneLeftSideCCG(match_start, match_end);
- if (label == 0) label = syntax.getOneDoubleConcatenation(match_start, match_end);
- if (label == 0) {
- System.err.println("[FAIL] No label found in line:\t" + line);
- continue;
- }
++ int label = syntax.getOneConstituent(match_start, match_end);
++ if (label == 0) label = syntax.getOneSingleConcatenation(match_start, match_end);
++ if (label == 0) label = syntax.getOneRightSideCCG(match_start, match_end);
++ if (label == 0) label = syntax.getOneLeftSideCCG(match_start, match_end);
++ if (label == 0) label = syntax.getOneDoubleConcatenation(match_start, match_end);
++ if (label == 0) {
++ System.err.println("[FAIL] No label found in line:\t" + line);
++ continue;
++ }
+
- System.out.println(Vocabulary.word(label) + "\t" + line);
++ System.out.println(Vocabulary.word(label) + "\t" + line);
++ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
index ecb2e6e,0000000..f73f02d
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
+++ b/joshua-core/src/main/java/org/apache/joshua/tools/TestSetFilter.java
@@@ -1,383 -1,0 +1,384 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.tools;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TestSetFilter {
+
+ private static final Logger LOG = LoggerFactory.getLogger(TestSetFilter.class);
+
+ private Filter filter = null;
+
+ // for caching of accepted rules
+ private String lastSourceSide;
+ private boolean acceptedLastSourceSide;
+
+ public int cached = 0;
+ public int RULE_LENGTH = 12;
+ public boolean verbose = false;
+ public boolean parallel = false;
+
+ private static final String DELIMITER = "|||";
+ private static final String DELIMITER_REGEX = " \\|\\|\\| ";
+ public static final String DELIM = String.format(" %s ", DELIMITER);
+ public static final Pattern P_DELIM = Pattern.compile(DELIMITER_REGEX);
+ private final String NT_REGEX = "\\[[^\\]]+?\\]";
+
+ public TestSetFilter() {
+ acceptedLastSourceSide = false;
+ lastSourceSide = null;
+ }
-
++
+ public String getFilterName() {
+ if (filter != null)
+ if (filter instanceof FastFilter)
+ return "fast";
+ else if (filter instanceof LooseFilter)
+ return "loose";
+ else
+ return "exact";
+ return "null";
+ }
+
+ public void setVerbose(boolean value) {
+ verbose = value;
+ }
+
+ public void setParallel(boolean value) {
+ parallel = value;
+ }
+
+ public void setFilter(String type) {
+ if (type.equals("fast"))
+ filter = new FastFilter();
+ else if (type.equals("exact"))
+ filter = new ExactFilter();
+ else if (type.equals("loose"))
+ filter = new LooseFilter();
+ else
+ throw new RuntimeException(String.format("Invalid filter type '%s'", type));
+ }
+
+ public void setRuleLength(int value) {
+ RULE_LENGTH = value;
+ }
+
+ private void loadTestSentences(String filename) throws IOException {
+ int count = 0;
+
+ try {
+ for (String line: new LineReader(filename)) {
+ filter.addSentence(line);
+ count++;
+ }
+ } catch (FileNotFoundException e) {
+ LOG.error(e.getMessage(), e);
+ }
+
+ if (verbose)
+ System.err.println(String.format("Added %d sentences.\n", count));
+ }
+
+ /**
- * Top-level filter, responsible for calling the fast or exact version. Takes the source side
++ * Top-level filter, responsible for calling the fast or exact version. Takes the source side
+ * of a rule and determines whether there is any sentence in the test set that can match it.
+ * @param sourceSide an input source sentence
+ * @return true if is any sentence in the test set can match the source input
+ */
+ public boolean inTestSet(String sourceSide) {
+ if (!sourceSide.equals(lastSourceSide)) {
+ lastSourceSide = sourceSide;
+ acceptedLastSourceSide = filter.permits(sourceSide);
+ } else {
+ cached++;
+ }
+
+ return acceptedLastSourceSide;
+ }
-
++
+ /**
+ * Determines whether a rule is an abstract rule. An abstract rule is one that has no terminals on
+ * its source side.
- *
++ *
+ * If the rule is abstract, the rule's arity is returned. Otherwise, 0 is returned.
+ */
+ private boolean isAbstract(String source) {
+ int nonterminalCount = 0;
+ for (String t : source.split("\\s+")) {
+ if (!t.matches(NT_REGEX))
+ return false;
+ nonterminalCount++;
+ }
+ return nonterminalCount != 0;
+ }
+
+ private interface Filter {
+ /* Tell the filter about a sentence in the test set being filtered to */
+ public void addSentence(String sentence);
-
++
+ /* Returns true if the filter permits the specified source side */
+ public boolean permits(String sourceSide);
+ }
+
+ private class FastFilter implements Filter {
+ private Set<String> ngrams = null;
+
+ public FastFilter() {
+ ngrams = new HashSet<String>();
+ }
-
++
+ @Override
+ public boolean permits(String source) {
+ for (String chunk : source.split(NT_REGEX)) {
+ chunk = chunk.trim();
+ /* Important: you need to make sure the string isn't empty. */
+ if (!chunk.equals("") && !ngrams.contains(chunk))
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void addSentence(String sentence) {
+ String[] tokens = sentence.trim().split("\\s+");
+ int maxOrder = RULE_LENGTH < tokens.length ? RULE_LENGTH : tokens.length;
+ for (int order = 1; order <= maxOrder; order++) {
+ for (int start = 0; start < tokens.length - order + 1; start++)
+ ngrams.add(createNGram(tokens, start, order));
+ }
+ }
+
+ private String createNGram(String[] tokens, int start, int order) {
+ if (order < 1 || start + order > tokens.length) {
+ return "";
+ }
+ String result = tokens[start];
+ for (int i = 1; i < order; i++)
+ result += " " + tokens[start + i];
+ return result;
+ }
+ }
+
+ private class LooseFilter implements Filter {
+ List<String> testSentences = null;
+
+ public LooseFilter() {
+ testSentences = new ArrayList<String>();
+ }
-
++
+ @Override
+ public void addSentence(String source) {
+ testSentences.add(source);
+ }
+
+ @Override
+ public boolean permits(String source) {
+ Pattern pattern = getPattern(source);
+ for (String testSentence : testSentences) {
+ if (pattern.matcher(testSentence).find()) {
+ return true;
+ }
+ }
+ return isAbstract(source);
+ }
+
+ protected Pattern getPattern(String source) {
+ String pattern = source;
+ pattern = pattern.replaceAll(String.format("\\s*%s\\s*", NT_REGEX), ".+");
+ pattern = pattern.replaceAll("\\s+", ".*");
+// System.err.println(String.format("PATTERN(%s) = %s", source, pattern));
+ return Pattern.compile(pattern);
+ }
+ }
+
+ /**
+ * This class is the same as LooseFilter except with a tighter regex for matching rules.
+ */
+ private class ExactFilter implements Filter {
+ private FastFilter fastFilter = null;
+ private Map<String, Set<Integer>> sentencesByWord;
+ List<String> testSentences = null;
-
++
+ public ExactFilter() {
+ fastFilter = new FastFilter();
+ sentencesByWord = new HashMap<String, Set<Integer>>();
+ testSentences = new ArrayList<String>();
+ }
-
++
+ @Override
+ public void addSentence(String source) {
+ fastFilter.addSentence(source);
+ addSentenceToWordHash(source, testSentences.size());
+ testSentences.add(source);
+ }
+
+ /**
+ * Always permit abstract rules. Otherwise, query the fast filter, and if that passes, apply
- *
++ *
+ */
+ @Override
+ public boolean permits(String sourceSide) {
+ if (isAbstract(sourceSide))
+ return true;
-
++
+ if (fastFilter.permits(sourceSide)) {
+ Pattern pattern = getPattern(sourceSide);
+ for (int i : getSentencesForRule(sourceSide)) {
+ if (pattern.matcher(testSentences.get(i)).find()) {
+ return true;
+ }
+ }
- }
++ }
+ return false;
+ }
-
++
+ protected Pattern getPattern(String source) {
+ String pattern = Pattern.quote(source);
+ pattern = pattern.replaceAll(NT_REGEX, "\\\\E.+\\\\Q");
+ pattern = pattern.replaceAll("\\\\Q\\\\E", "");
+ pattern = "(?:^|\\s)" + pattern + "(?:$|\\s)";
+ return Pattern.compile(pattern);
+ }
-
++
+ /*
+ * Map words to all the sentences they appear in.
+ */
+ private void addSentenceToWordHash(String sentence, int index) {
+ String[] tokens = sentence.split("\\s+");
+ for (String t : tokens) {
+ if (! sentencesByWord.containsKey(t))
+ sentencesByWord.put(t, new HashSet<Integer>());
+ sentencesByWord.get(t).add(index);
+ }
+ }
-
++
+ private Set<Integer> getSentencesForRule(String source) {
+ Set<Integer> sentences = null;
+ for (String token : source.split("\\s+")) {
+ if (!token.matches(NT_REGEX)) {
+ if (sentencesByWord.containsKey(token)) {
+ if (sentences == null)
+ sentences = new HashSet<Integer>(sentencesByWord.get(token));
+ else
+ sentences.retainAll(sentencesByWord.get(token));
+ }
+ }
+ }
-
++
+ return sentences;
+ }
+ }
+
+ public static void main(String[] argv) throws IOException {
+ // do some setup
+ if (argv.length < 1) {
+ System.err.println("usage: TestSetFilter [-v|-p|-f|-e|-l|-n N|-g grammar] test_set1 [test_set2 ...]");
+ System.err.println(" -g grammar file (can also be on STDIN)");
+ System.err.println(" -v verbose output");
+ System.err.println(" -p parallel compatibility");
+ System.err.println(" -f fast mode (default)");
+ System.err.println(" -e exact mode (slower)");
+ System.err.println(" -l loose mode");
+ System.err.println(" -n max n-gram to compare to (default 12)");
+ return;
+ }
-
++
+ String grammarFile = null;
+
+ TestSetFilter filter = new TestSetFilter();
+
+ for (int i = 0; i < argv.length; i++) {
+ if (argv[i].equals("-v")) {
+ filter.setVerbose(true);
+ continue;
+ } else if (argv[i].equals("-p")) {
+ filter.setParallel(true);
+ continue;
+ } else if (argv[i].equals("-g")) {
+ grammarFile = argv[++i];
+ continue;
+ } else if (argv[i].equals("-f")) {
+ filter.setFilter("fast");
+ continue;
+ } else if (argv[i].equals("-e")) {
+ filter.setFilter("exact");
+ continue;
+ } else if (argv[i].equals("-l")) {
+ filter.setFilter("loose");
+ continue;
+ } else if (argv[i].equals("-n")) {
+ filter.setRuleLength(Integer.parseInt(argv[i + 1]));
+ i++;
+ continue;
+ }
+
+ filter.loadTestSentences(argv[i]);
+ }
+
+ int rulesIn = 0;
+ int rulesOut = 0;
+ if (filter.verbose) {
+ System.err.println(String.format("Filtering rules with the %s filter...", filter.getFilterName()));
+// System.err.println("Using at max " + filter.RULE_LENGTH + " n-grams...");
+ }
- LineReader reader = (grammarFile != null)
++ try(LineReader reader = (grammarFile != null)
+ ? new LineReader(grammarFile, filter.verbose)
- : new LineReader(System.in);
- for (String rule: reader) {
- rulesIn++;
-
- String[] parts = P_DELIM.split(rule);
- if (parts.length >= 4) {
- // the source is the second field for thrax grammars, first field for phrasal ones
- String source = rule.startsWith("[") ? parts[1].trim() : parts[0].trim();
- if (filter.inTestSet(source)) {
- System.out.println(rule);
- if (filter.parallel)
++ : new LineReader(System.in);) {
++ for (String rule: reader) {
++ rulesIn++;
++
++ String[] parts = P_DELIM.split(rule);
++ if (parts.length >= 4) {
++ // the source is the second field for thrax grammars, first field for phrasal ones
++ String source = rule.startsWith("[") ? parts[1].trim() : parts[0].trim();
++ if (filter.inTestSet(source)) {
++ System.out.println(rule);
++ if (filter.parallel)
++ System.out.flush();
++ rulesOut++;
++ } else if (filter.parallel) {
++ System.out.println("");
+ System.out.flush();
- rulesOut++;
- } else if (filter.parallel) {
- System.out.println("");
- System.out.flush();
++ }
+ }
+ }
- }
- if (filter.verbose) {
- System.err.println("[INFO] Total rules read: " + rulesIn);
- System.err.println("[INFO] Rules kept: " + rulesOut);
- System.err.println("[INFO] Rules dropped: " + (rulesIn - rulesOut));
- System.err.println("[INFO] cached queries: " + filter.cached);
- }
++ if (filter.verbose) {
++ System.err.println("[INFO] Total rules read: " + rulesIn);
++ System.err.println("[INFO] Rules kept: " + rulesOut);
++ System.err.println("[INFO] Rules dropped: " + (rulesIn - rulesOut));
++ System.err.println("[INFO] cached queries: " + filter.cached);
++ }
+
- return;
++ return;
++ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/Constants.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/Constants.java
index bcabfe4,0000000..669023b
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/Constants.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/Constants.java
@@@ -1,44 -1,0 +1,45 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
/**
 * One day, all constants should be moved here (many are in Vocabulary).
 *
 * @author Matt Post post@cs.jhu.edu
 */
public final class Constants {
  /** Default nonterminal, {@code [X]}. */
  public static final String defaultNT = "[X]";

  /** Start symbol, {@code <s>}. */
  public static final String START_SYM = "<s>";
  /** Stop symbol, {@code </s>}. */
  public static final String STOP_SYM = "</s>";
  /** Unknown-word symbol, {@code <unk>}. */
  public static final String UNKNOWN_WORD = "<unk>";

  /** Regex matching the {@code |||} field delimiter with one whitespace character on each side. */
  public static final String fieldDelimiterPattern = "\\s\\|{3}\\s";
  /** Literal field delimiter, {@code " ||| "}. */
  public static final String fieldDelimiter = " ||| ";

  /** Regex matching a run of one or more whitespace characters. */
  public static final String spaceSeparator = "\\s+";

  /** Regex matching a bracketed nonterminal such as {@code [X]}. */
  public static final String NT_REGEX = "\\[[^\\]]+?\\]";

  public static final String TM_PREFIX = "tm";

  public static final String labeledFeatureSeparator = "=";

  private Constants() {
    // Not instantiable: this class only holds static constants.
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
index a36b07f,0000000..0f13e6a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/FileUtility.java
@@@ -1,318 -1,0 +1,63 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
- import java.io.BufferedReader;
+import java.io.BufferedWriter;
- import java.io.Closeable;
+import java.io.File;
+import java.io.FileDescriptor;
- import java.io.FileInputStream;
+import java.io.FileOutputStream;
- import java.io.FileReader;
+import java.io.IOException;
- import java.io.InputStream;
- import java.io.OutputStream;
+import java.io.OutputStreamWriter;
- import java.nio.charset.Charset;
- import java.util.LinkedList;
- import java.util.List;
- import java.util.Scanner;
++import java.nio.charset.StandardCharsets;
+
+/**
+ * utility functions for file operations
- *
++ *
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @since 28 February 2009
+ */
+public class FileUtility {
- public static String DEFAULT_ENCODING = "UTF-8";
-
- /*
- * Note: charset name is case-agnostic "UTF-8" is the canonical name "UTF8", "unicode-1-1-utf-8"
- * are aliases Java doesn't distinguish utf8 vs UTF-8 like Perl does
- */
- private static final Charset FILE_ENCODING = Charset.forName(DEFAULT_ENCODING);
+
+ /**
+ * Warning, will truncate/overwrite existing files
++ * <p>
++ * Output is encoded as UTF-8. The special name "-" writes to standard output instead of a file.
++ * NOTE(review): for "-" the writer wraps a new FileOutputStream on FileDescriptor.out, so closing
++ * the returned writer will likely close the process's standard output as well -- confirm callers
++ * expect this.
+ * @param filename a file for which to obtain a writer
+ * @return the buffered writer object
+ * @throws IOException if the output file cannot be opened for writing
+ */
+ public static BufferedWriter getWriteFileStream(String filename) throws IOException {
+ return new BufferedWriter(new OutputStreamWriter(
+ // TODO: add GZIP
+ filename.equals("-") ? new FileOutputStream(FileDescriptor.out) : new FileOutputStream(
- filename, false), FILE_ENCODING));
- }
-
- /**
- * Recursively delete the specified file or directory.
- *
- * @param f File or directory to delete
- * @return <code>true</code> if the specified file or directory was deleted, <code>false</code>
- * otherwise
- */
- public static boolean deleteRecursively(File f) {
- if (null != f) {
- if (f.isDirectory())
- for (File child : f.listFiles())
- deleteRecursively(child);
- return f.delete();
- } else {
- return false;
- }
- }
-
- /**
- * Writes data from the integer array to disk as raw bytes, overwriting the old file if present.
- *
- * @param data The integer array to write to disk.
- * @param filename The filename where the data should be written.
- * @throws IOException if there is a problem writing to the output file
- * @return the FileOutputStream on which the bytes were written
- */
- public static FileOutputStream writeBytes(int[] data, String filename) throws IOException {
- FileOutputStream out = new FileOutputStream(filename, false);
- writeBytes(data, out);
- return out;
- }
-
- /**
- * Writes data from the integer array to disk as raw bytes.
- *
- * @param data The integer array to write to disk.
- * @param out The output stream where the data should be written.
- * @throws IOException if there is a problem writing bytes
- */
- public static void writeBytes(int[] data, OutputStream out) throws IOException {
-
- byte[] b = new byte[4];
-
- for (int word : data) {
- b[0] = (byte) ((word >>> 24) & 0xFF);
- b[1] = (byte) ((word >>> 16) & 0xFF);
- b[2] = (byte) ((word >>> 8) & 0xFF);
- b[3] = (byte) ((word >>> 0) & 0xFF);
-
- out.write(b);
- }
- }
-
- public static void copyFile(String srFile, String dtFile) throws IOException {
- try {
- File f1 = new File(srFile);
- File f2 = new File(dtFile);
- copyFile(f1, f2);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public static void copyFile(File srFile, File dtFile) throws IOException {
- try {
-
- InputStream in = new FileInputStream(srFile);
-
- // For Append the file.
- // OutputStream out = new FileOutputStream(f2,true);
-
- // For Overwrite the file.
- OutputStream out = new FileOutputStream(dtFile);
-
- byte[] buf = new byte[1024];
- int len;
- while ((len = in.read(buf)) > 0) {
- out.write(buf, 0, len);
- }
- in.close();
- out.close();
- System.out.println("File copied.");
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- static public boolean deleteFile(String fileName) {
-
- File f = new File(fileName);
-
- // Make sure the file or directory exists and isn't write protected
- if (!f.exists())
- System.out.println("Delete: no such file or directory: " + fileName);
-
- if (!f.canWrite())
- System.out.println("Delete: write protected: " + fileName);
-
- // If it is a directory, make sure it is empty
- if (f.isDirectory()) {
- String[] files = f.list();
- if (files.length > 0)
- System.out.println("Delete: directory not empty: " + fileName);
- }
-
- // Attempt to delete it
- boolean success = f.delete();
-
- if (!success)
- System.out.println("Delete: deletion failed");
-
- return success;
-
++ filename, false), StandardCharsets.UTF_8));
+ }
+
+ /**
+ * Returns the base directory of the file. For example, dirname('/usr/local/bin/emacs') ->
+ * '/usr/local/bin'
++ * <p>
++ * Only {@link File#separator} is recognized as a separator. A name that contains no separator
++ * yields ".", and a name whose last separator is its first character (e.g. "/emacs") yields the
++ * empty string.
+ * @param fileName the input path
+ * @return the parent path
+ */
+ static public String dirname(String fileName) {
+ if (fileName.indexOf(File.separator) != -1)
+ return fileName.substring(0, fileName.lastIndexOf(File.separator));
+
+ return ".";
+ }
-
- public static void createFolderIfNotExisting(String folderName) {
- File f = new File(folderName);
- if (!f.isDirectory()) {
- System.out.println(" createFolderIfNotExisting -- Making directory: " + folderName);
- f.mkdirs();
- } else {
- System.out.println(" createFolderIfNotExisting -- Directory: " + folderName
- + " already existed");
- }
- }
-
- public static void closeCloseableIfNotNull(Closeable fileWriter) {
- if (fileWriter != null) {
- try {
- fileWriter.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- }
-
- /**
- * Returns the directory were the program has been started,
- * the base directory you will implicitly get when specifying no
- * full path when e.g. opening a file
- * @return the current 'user.dir'
- */
- public static String getWorkingDirectory() {
- return System.getProperty("user.dir");
- }
-
- /**
- * Method to handle standard IO exceptions. catch (Exception e) {Utility.handleIO_exception(e);}
- * @param e an input {@link java.lang.Exception}
- */
- public static void handleExceptions(Exception e) {
- throw new RuntimeException(e);
- }
-
- /**
- * Convenience method to get a full file as a String
- * @param file the input {@link java.io.File}
- * @return The file as a String. Lines are separated by newline character.
- */
- public static String getFileAsString(File file) {
- String result = "";
- List<String> lines = getLines(file, true);
- for (int i = 0; i < lines.size() - 1; i++) {
- result += lines.get(i) + "\n";
- }
- if (!lines.isEmpty()) {
- result += lines.get(lines.size() - 1);
- }
- return result;
- }
-
- /**
- * This method returns a List of String. Each element of the list corresponds to a line from the
- * input file. The boolean keepDuplicates in the input determines if duplicate lines are allowed
- * in the output LinkedList or not.
- * @param file the input file
- * @param keepDuplicates whether to retain duplicate lines
- * @return a {@link java.util.List} of lines
- */
- static public List<String> getLines(File file, boolean keepDuplicates) {
- LinkedList<String> list = new LinkedList<String>();
- String line = "";
- try {
- BufferedReader InputReader = new BufferedReader(new FileReader(file));
- for (;;) { // this loop writes writes in a Sting each sentence of
- // the file and process it
- int current = InputReader.read();
- if (current == -1 || current == '\n') {
- if (keepDuplicates || !list.contains(line))
- list.add(line);
- line = "";
- if (current == -1)
- break; // EOF
- } else
- line += (char) current;
- }
- InputReader.close();
- } catch (Exception e) {
- handleExceptions(e);
- }
- return list;
- }
-
- /**
- * Returns a Scanner of the inputFile using a specific encoding
- *
- * @param inputFile the file for which to get a {@link java.util.Scanner} object
- * @param encoding the encoding to use within the Scanner
- * @return a {@link java.util.Scanner} object for a given file
- */
- public static Scanner getScanner(File inputFile, String encoding) {
- Scanner scan = null;
- try {
- scan = new Scanner(inputFile, encoding);
- } catch (IOException e) {
- FileUtility.handleExceptions(e);
- }
- return scan;
- }
-
- /**
- * Returns a Scanner of the inputFile using default encoding
- *
- * @param inputFile the file for which to get a {@link java.util.Scanner} object
- * @return a {@link java.util.Scanner} object for a given file
- */
- public static Scanner getScanner(File inputFile) {
- return getScanner(inputFile, DEFAULT_ENCODING);
- }
-
- static public String getFirstLineInFile(File inputFile) {
- Scanner scan = FileUtility.getScanner(inputFile);
- if (!scan.hasNextLine())
- return null;
- String line = scan.nextLine();
- scan.close();
- return line;
- }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
index afb5af1,0000000..14154e8
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/ListUtil.java
@@@ -1,95 -1,0 +1,44 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util;
+
+import java.util.List;
+
++/**
++ * Helpers that render a list of strings as a single delimited string.
++ */
+public class ListUtil {
+
- /**
- * Static method to generate a list representation for an ArrayList of Strings S1,...,Sn
- *
- * @param list A list of Strings
- * @return A String consisting of the original list of strings concatenated and separated by
- * commas, and enclosed by square brackets i.e. '[S1,S2,...,Sn]'
- */
- public static String stringListString(List<String> list) {
-
- String result = "[";
- for (int i = 0; i < list.size() - 1; i++) {
- result += list.get(i) + ",";
- }
-
- if (list.size() > 0) {
- // get the generated word for the last target position
- result += list.get(list.size() - 1);
- }
-
- result += "]";
-
- return result;
-
- }
-
- public static <E> String objectListString(List<E> list) {
- String result = "[";
- for (int i = 0; i < list.size() - 1; i++) {
- result += list.get(i) + ",";
- }
- if (list.size() > 0) {
- // get the generated word for the last target position
- result += list.get(list.size() - 1);
- }
- result += "]";
- return result;
- }
-
- /**
- * Static method to generate a simple concatenated representation for an ArrayList of Strings
- * S1,...,Sn
- *
- * @param list A list of Strings
- * @return todo
- */
- public static String stringListStringWithoutBrackets(List<String> list) {
- return stringListStringWithoutBracketsWithSpecifiedSeparator(list, " ");
- }
-
++ /**
++ * Joins the elements of {@code list} with "," and no surrounding brackets, e.g. the list
++ * [a, b, c] becomes "a,b,c".
++ *
++ * @param list the strings to join
++ * @return the comma-separated string; "" for an empty list
++ */
+ public static String stringListStringWithoutBracketsCommaSeparated(List<String> list) {
+ return stringListStringWithoutBracketsWithSpecifiedSeparator(list, ",");
+ }
+
++ /**
++ * Joins the elements of {@code list}, placing {@code separator} between consecutive elements
++ * (nothing before the first element or after the last). Null elements are rendered as "null".
++ * <p>
++ * NOTE(review): each {@code +=} copies the accumulated string, so the cost grows quadratically
++ * with the output length; for a non-null separator, {@code String.join(separator, list)} yields
++ * the same text in linear time.
++ *
++ * @param list the strings to join
++ * @param separator text inserted between consecutive elements
++ * @return the joined string; "" for an empty list
++ */
+ public static String stringListStringWithoutBracketsWithSpecifiedSeparator(List<String> list,
+ String separator) {
+
+ String result = "";
+ for (int i = 0; i < list.size() - 1; i++) {
+ result += list.get(i) + separator;
+ }
+
+ if (list.size() > 0) {
+ // append the last element; no separator follows it
+ result += list.get(list.size() - 1);
+ }
+
+ return result;
-
+ }
-
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
index ad2910c,0000000..d9bab66
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
@@@ -1,235 -1,0 +1,236 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.encoding;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.TreeMap;
+
+import org.apache.joshua.util.io.LineReader;
+
++/**
++ * Collects a histogram of float values and uses it to choose a {@link FloatEncoder} for them:
++ * either a primitive encoder whose range covers every value, or an {@link EightBitQuantizer}
++ * built from {@link #quantize(int)}.
++ */
+public class Analyzer {
+
++ // Each distinct value passed to add() mapped to its number of occurrences; initialize() also
++ // seeds the 0.0f key with a count of 0.
+ private TreeMap<Float, Integer> histogram;
++ // Number of add() calls since the last initialize().
+ private int total;
+
++ /** Creates an empty analyzer. */
+ public Analyzer() {
+ histogram = new TreeMap<Float, Integer>();
+ initialize();
+ }
+
++ /**
++ * Discards all collected values and re-seeds the 0.0f key with count 0. quantize() relies on
++ * that entry: it unboxes histogram.get(0.0f) directly.
++ */
+ public void initialize() {
+ histogram.clear();
+ // TODO: drop zero bucket; we won't encode zero-valued features anyway.
+ histogram.put(0.0f, 0);
+ total = 0;
+ }
+
++ /**
++ * Records one occurrence of {@code key}.
++ *
++ * @param key the observed value
++ */
+ public void add(float key) {
+ if (histogram.containsKey(key))
+ histogram.put(key, histogram.get(key) + 1);
+ else
+ histogram.put(key, 1);
+ total++;
+ }
+
++ /**
++ * Chooses representative bucket values for the collected data.
++ * <p>
++ * Slot 0 is always 0.0f. A target bucket size is first estimated from the number of non-zero
++ * observations, then refined: keys whose own count reaches the target ("spikes"), and the zero
++ * key, each use up a slot and are left out, and the target is recomputed until it stops changing.
++ * The keys are then scanned in ascending order (TreeMap order) and grouped. A bucket is closed
++ * when it reaches the target size, before a spike, at 0.0f, and where the values turn from
++ * negative to positive. Each bucket is stored as the count-weighted mean of its keys.
++ * <p>
++ * NOTE(review): {@code size = sum / count} throws ArithmeticException when the excluded keys
++ * number exactly 2^num_bits - 1, and gives a non-positive size when there are more of them.
++ * num_bits == 0 divides by zero in the initial estimate. The scan only stores a bucket while
++ * index < buckets.length - 2, and the final flush only while index < buckets.length - 1, so the
++ * last slot is never written -- confirm this is intended.
++ *
++ * @param num_bits number of bits per encoded value
++ * @return bucket values with index 0 holding 0.0f, truncated to the slots actually used (at most
++ * 2^num_bits - 1 entries)
++ */
+ public float[] quantize(int num_bits) {
+ float[] buckets = new float[1 << num_bits];
+
+ // We make sure that 0.0f always has its own bucket, so the bucket
+ // size is determined excluding the zero values.
+ int size = (total - histogram.get(0.0f)) / (buckets.length - 1);
+ buckets[0] = 0.0f;
+
++ // Refine the target size: keys with count >= size (and the zero key) each take a slot of their
++ // own and are left out of the average over the remaining slots.
+ int old_size = -1;
+ while (old_size != size) {
+ int sum = 0;
+ int count = buckets.length - 1;
+ for (float key : histogram.keySet()) {
+ int entry_count = histogram.get(key);
+ if (entry_count < size && key != 0)
+ sum += entry_count;
+ else
+ count--;
+ }
+ old_size = size;
+ size = sum / count;
+ }
+
++ // Starts positive so the first key can never trigger the negative-to-positive split.
+ float last_key = Float.MAX_VALUE;
+
+ int index = 1;
+ int count = 0;
+ float sum = 0.0f;
+
+ int value;
+ for (float key : histogram.keySet()) {
+ value = histogram.get(key);
+ // Special bucket termination cases: zero boundary and histogram spikes.
+ if (key == 0 || (last_key < 0 && key > 0) || (value >= size)) {
+ // If the count is not 0, i.e. there were negative values, we should
+ // not bucket them with the positive ones. Close out the bucket now.
+ if (count != 0 && index < buckets.length - 2) {
- buckets[index++] = (float) sum / count;
++ buckets[index++] = sum / count;
+ count = 0;
+ sum = 0;
+ }
+ if (key == 0)
+ continue;
+ }
+ count += value;
+ sum += key * value;
+ // Check if the bucket is full.
+ if (count >= size && index < buckets.length - 2) {
- buckets[index++] = (float) sum / count;
++ buckets[index++] = sum / count;
+ count = 0;
+ sum = 0;
+ }
+ last_key = key;
+ }
++ // Flush the last, possibly partial, bucket.
+ if (count > 0 && index < buckets.length - 1)
- buckets[index++] = (float) sum / count;
-
++ buckets[index++] = sum / count;
++
+ float[] shortened = new float[index];
+ for (int i = 0; i < shortened.length; ++i)
+ shortened[i] = buckets[i];
+ return shortened;
+ }
+
++ /** Returns true if every key in the histogram is 0 or 1. */
+ public boolean isBoolean() {
+ for (float value : histogram.keySet())
+ if (value != 0 && value != 1)
+ return false;
+ return true;
+ }
+
++ /** Returns true if every key is a whole number within the byte range. */
+ public boolean isByte() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value || value < Byte.MIN_VALUE || value > Byte.MAX_VALUE)
+ return false;
+ return true;
+ }
+
++ /** Returns true if every key is a whole number within the short range. */
+ public boolean isShort() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value || value < Short.MIN_VALUE || value > Short.MAX_VALUE)
+ return false;
+ return true;
+ }
+
++ /** Returns true if every key is a whole number within the char range (0 to 65535). */
+ public boolean isChar() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value || value < Character.MIN_VALUE || value > Character.MAX_VALUE)
+ return false;
+ return true;
+ }
+
++ /**
++ * Returns true if every key is a whole number.
++ * <p>
++ * NOTE(review): there is no range check here, so whole numbers outside the int range, and
++ * +/-Infinity (for which Math.ceil returns its argument), also pass.
++ */
+ public boolean isInt() {
+ for (float value : histogram.keySet())
+ if (Math.ceil(value) != value)
+ return false;
+ return true;
+ }
+
++ /** Returns true if there are at most 256 distinct keys (the seeded 0.0f key included). */
+ public boolean is8Bit() {
+ return (histogram.keySet().size() <= 256);
+ }
+
++ /**
++ * Picks an encoder by trying, in order: BOOLEAN, BYTE, an 8-bit quantizer (when is8Bit()),
++ * CHAR, SHORT, INT, and finally FLOAT.
++ * <p>
++ * NOTE(review): is8Bit() only bounds the number of distinct keys; quantize(8) can still merge
++ * values into shared buckets, so this choice is not necessarily lossless -- confirm that is
++ * acceptable here.
++ */
+ public FloatEncoder inferUncompressedType() {
+ if (isBoolean())
+ return PrimitiveFloatEncoder.BOOLEAN;
+ if (isByte())
+ return PrimitiveFloatEncoder.BYTE;
+ if (is8Bit())
+ return new EightBitQuantizer(this.quantize(8));
+ if (isChar())
+ return PrimitiveFloatEncoder.CHAR;
+ if (isShort())
+ return PrimitiveFloatEncoder.SHORT;
+ if (isInt())
+ return PrimitiveFloatEncoder.INT;
+ return PrimitiveFloatEncoder.FLOAT;
+ }
-
++
++ /**
++ * Like inferUncompressedType(), but returns an 8-bit quantizer whenever {@code bits == 8}, even
++ * when the values would fit a wider primitive type. Other values of {@code bits} currently do
++ * not affect the result (see the TODO below).
++ *
++ * @param bits requested bits per value
++ * @return the chosen encoder
++ */
+ public FloatEncoder inferType(int bits) {
+ if (isBoolean())
+ return PrimitiveFloatEncoder.BOOLEAN;
+ if (isByte())
+ return PrimitiveFloatEncoder.BYTE;
+ if (bits == 8 || is8Bit())
+ return new EightBitQuantizer(this.quantize(8));
+ // TODO: Could add sub-8-bit encoding here (or larger).
+ if (isChar())
+ return PrimitiveFloatEncoder.CHAR;
+ if (isShort())
+ return PrimitiveFloatEncoder.SHORT;
+ if (isInt())
+ return PrimitiveFloatEncoder.INT;
+ return PrimitiveFloatEncoder.FLOAT;
+ }
+
++ /**
++ * Dumps the histogram, one line per distinct key: label, key formatted with 5 decimals, and
++ * count, separated by tabs. String.format uses the default locale, so the decimal separator may
++ * vary between environments.
++ *
++ * @param label text placed at the start of every line
++ * @return the formatted histogram
++ */
+ public String toString(String label) {
+ StringBuilder sb = new StringBuilder();
+ for (float val : histogram.keySet())
+ sb.append(label + "\t" + String.format("%.5f", val) + "\t" + histogram.get(val) + "\n");
+ return sb.toString();
+ }
-
++
++ /**
++ * Command-line check of the quantizer. Reads floats from the file args[0] (presumably one per
++ * line, via LineReader), learns an 8-bit quantizer from them, and encodes every value into a
++ * buffer of n bytes. It then decodes them and prints the mean absolute error over the non-zero
++ * inputs.
++ */
+ public static void main(String[] args) throws IOException {
- LineReader reader = new LineReader(args[0]);
- ArrayList<Float> s = new ArrayList<Float>();
-
- System.out.println("Initialized.");
- while (reader.hasNext())
- s.add(Float.parseFloat(reader.next().trim()));
- System.out.println("Data read.");
- int n = s.size();
- byte[] c = new byte[n];
- ByteBuffer b = ByteBuffer.wrap(c);
- Analyzer q = new Analyzer();
-
- q.initialize();
- for (int i = 0; i < n; i++)
- q.add(s.get(i));
- EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
- System.out.println("Quantizer learned.");
-
- for (int i = 0; i < n; i++)
- eb.write(b, s.get(i));
- b.rewind();
- System.out.println("Quantization complete.");
-
- float avg_error = 0;
- float error = 0;
- int count = 0;
- for (int i = -4; i < n - 4; i++) {
- float coded = eb.read(b, i);
- if (s.get(i + 4) != 0) {
- error = Math.abs(s.get(i + 4) - coded);
- avg_error += error;
- count++;
++ try (LineReader reader = new LineReader(args[0]);) {
++ ArrayList<Float> s = new ArrayList<Float>();
++
++ System.out.println("Initialized.");
++ while (reader.hasNext())
++ s.add(Float.parseFloat(reader.next().trim()));
++ System.out.println("Data read.");
++ int n = s.size();
++ byte[] c = new byte[n];
++ ByteBuffer b = ByteBuffer.wrap(c);
++ Analyzer q = new Analyzer();
++
++ q.initialize();
++ for (int i = 0; i < n; i++)
++ q.add(s.get(i));
++ EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
++ System.out.println("Quantizer learned.");
++
++ for (int i = 0; i < n; i++)
++ eb.write(b, s.get(i));
++ b.rewind();
++ System.out.println("Quantization complete.");
++
++ float avg_error = 0;
++ float error = 0;
++ int count = 0;
++ // NOTE(review): read index i runs over [-4, n - 4) and is paired with sample i + 4 --
++ // presumably matching EightBitQuantizer.read()'s offset convention; confirm.
++ for (int i = -4; i < n - 4; i++) {
++ float coded = eb.read(b, i);
++ if (s.get(i + 4) != 0) {
++ error = Math.abs(s.get(i + 4) - coded);
++ avg_error += error;
++ count++;
++ }
+ }
- }
- avg_error /= count;
- System.out.println("Evaluation complete.");
++ // If every sample is 0, count stays 0 and avg_error becomes NaN.
++ avg_error /= count;
++ System.out.println("Evaluation complete.");
+
- System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
++ // NOTE(review): "quanitization" in the message below is a typo for "quantization".
++ System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
++ }
+ }
+}
[12/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/metrics/SARI.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/metrics/SARI.java
index 129e4af,0000000..9ee3af3
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/metrics/SARI.java
+++ b/joshua-core/src/main/java/org/apache/joshua/metrics/SARI.java
@@@ -1,681 -1,0 +1,630 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package org.apache.joshua.metrics;
+
- // Changed PROCore.java (text normalization function) and EvaluationMetric too
-
- import java.util.Map;
- import java.util.HashMap;
- import java.util.Iterator;
- import java.util.logging.Logger;
-
+import java.io.BufferedReader;
- import java.io.IOException;
- import java.io.InputStreamReader;
+import java.io.File;
+import java.io.FileInputStream;
++import java.io.IOException;
+import java.io.InputStream;
++import java.io.InputStreamReader;
++import java.util.HashMap;
++import java.util.Iterator;
++
++// Changed PROCore.java (text normalization function) and EvaluationMetric too
++
++import java.util.Map;
++import java.util.logging.Logger;
+
+/***
+ * Implementation of the SARI metric for text-to-text correction.
- *
++ *
+ * \@article{xu2016optimizing,
+ * title={Optimizing statistical machine translation for text simplification},
+ * author={Xu, Wei and Napoles, Courtney and Pavlick, Ellie and Chen, Quanze and Callison-Burch, Chris},
+ * journal={Transactions of the Association for Computational Linguistics},
+ * volume={4},
+ * year={2016}}
- *
++ *
+ * @author Wei Xu
+ */
+public class SARI extends EvaluationMetric {
+ private static final Logger logger = Logger.getLogger(SARI.class.getName());
+
+ // The maximum n-gram we care about
+ protected int maxGramLength;
+ protected String[] srcSentences;
+ protected double[] weights;
+ protected HashMap<String, Integer>[][] refNgramCounts;
+ protected HashMap<String, Integer>[][] srcNgramCounts;
+
+ /*
+ * You already have access to these data members of the parent class (EvaluationMetric): int
+ * numSentences; number of sentences in the MERT set int refsPerSen; number of references per
+ * sentence String[][] refSentences; refSentences[i][r] stores the r'th reference of the i'th
+ * source sentence (both indices are 0-based)
+ */
+
+ public SARI(String[] Metric_options) {
+ int mxGrmLn = Integer.parseInt(Metric_options[0]);
+ if (mxGrmLn >= 1) {
+ maxGramLength = mxGrmLn;
+ } else {
+ logger.severe("Maximum gram length must be positive");
+ System.exit(1);
+ }
+
+ try {
+ loadSources(Metric_options[1]);
+ } catch (IOException e) {
+ logger.severe("Error loading the source sentences from " + Metric_options[1]);
+ System.exit(1);
+ }
+
+ initialize(); // set the data members of the metric
+
+ }
+
++ @Override
+ protected void initialize() {
+ metricName = "SARI";
+ toBeMinimized = false;
+ suffStatsCount = StatIndex.values().length * maxGramLength + 1;
+
+ set_weightsArray();
+ set_refNgramCounts();
+ set_srcNgramCounts();
+
+ }
+
++ @Override
+ public double bestPossibleScore() {
+ return 1.0;
+ }
+
++ @Override
+ public double worstPossibleScore() {
+ return 0.0;
+ }
+
+ /**
+ * Sets the BLEU weights for each n-gram level to uniform.
+ */
+ protected void set_weightsArray() {
+ weights = new double[1 + maxGramLength];
+ for (int n = 1; n <= maxGramLength; ++n) {
+ weights[n] = 1.0 / maxGramLength;
+ }
+ }
+
+ /**
+ * Computes the sum of ngram counts in references for each sentence (storing them in
+ * <code>refNgramCounts</code>), which are used for clipping n-gram counts.
+ */
+ protected void set_refNgramCounts() {
+ @SuppressWarnings("unchecked")
+
+ HashMap<String, Integer>[][] temp_HMA = new HashMap[numSentences][maxGramLength];
+ refNgramCounts = temp_HMA;
+
+ String gram = "";
+ int oldCount = 0, nextCount = 0;
+
+ for (int i = 0; i < numSentences; ++i) {
+ refNgramCounts[i] = getNgramCountsArray(refSentences[i][0]);
+ // initialize to ngramCounts[n] of the first reference translation...
+
+ // ...and update as necessary from the other reference translations
+ for (int r = 1; r < refsPerSen; ++r) {
+
+ HashMap<String, Integer>[] nextNgramCounts = getNgramCountsArray(refSentences[i][r]);
+
+ for (int n = 1; n <= maxGramLength; ++n) {
+
+ Iterator<String> it = (nextNgramCounts[n].keySet()).iterator();
+
+ while (it.hasNext()) {
+ gram = it.next();
+ nextCount = nextNgramCounts[n].get(gram);
+
+ if (refNgramCounts[i][n].containsKey(gram)) { // update if necessary
+ oldCount = refNgramCounts[i][n].get(gram);
+ refNgramCounts[i][n].put(gram, oldCount + nextCount);
+ } else { // add it
+ refNgramCounts[i][n].put(gram, nextCount);
+ }
+
+ }
+
+ } // for (n)
+
+ } // for (r)
+
+ } // for (i)
+
+ }
+
+ protected void set_srcNgramCounts() {
+ @SuppressWarnings("unchecked")
+
+ HashMap<String, Integer>[][] temp_HMA = new HashMap[numSentences][maxGramLength];
+ srcNgramCounts = temp_HMA;
+
+ for (int i = 0; i < numSentences; ++i) {
+ srcNgramCounts[i] = getNgramCountsArray(srcSentences[i]);
+ } // for (i)
+ }
+
+ // set contents of stats[] here!
++ @Override
+ public int[] suffStats(String cand_str, int i) {
+ int[] stats = new int[suffStatsCount];
+
+ HashMap<String, Integer>[] candNgramCounts = getNgramCountsArray(cand_str);
+
+ for (int n = 1; n <= maxGramLength; ++n) {
+
+ // ADD OPERATIONS
- HashMap cand_sub_src = substractHashMap(candNgramCounts[n], srcNgramCounts[i][n]);
- HashMap cand_and_ref_sub_src = intersectHashMap(cand_sub_src, refNgramCounts[i][n]);
- HashMap ref_sub_src = substractHashMap(refNgramCounts[i][n], srcNgramCounts[i][n]);
++ HashMap<String, Integer> cand_sub_src = substractHashMap(candNgramCounts[n], srcNgramCounts[i][n]);
++ HashMap<String, Integer> cand_and_ref_sub_src = intersectHashMap(cand_sub_src, refNgramCounts[i][n]);
++ HashMap<String, Integer> ref_sub_src = substractHashMap(refNgramCounts[i][n], srcNgramCounts[i][n]);
+
+ stats[StatIndex.values().length * (n - 1)
+ + StatIndex.ADDBOTH.ordinal()] = cand_and_ref_sub_src.keySet().size();
+ stats[StatIndex.values().length * (n - 1) + StatIndex.ADDCAND.ordinal()] = cand_sub_src
+ .keySet().size();
+ stats[StatIndex.values().length * (n - 1) + StatIndex.ADDREF.ordinal()] = ref_sub_src.keySet()
+ .size();
+
+ // System.out.println("src_and_cand_sub_ref" + cand_and_ref_sub_src +
+ // cand_and_ref_sub_src.keySet().size());
+ // System.out.println("cand_sub_src" + cand_sub_src + cand_sub_src.keySet().size());
+ // System.out.println("ref_sub_src" + ref_sub_src + ref_sub_src.keySet().size());
+
+ // DELETION OPERATIONS
- HashMap src_sub_cand = substractHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_sub_ref = substractHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_sub_cand_sub_ref = intersectHashMap(src_sub_cand, src_sub_ref, 1, 1);
++ HashMap<String, Integer> src_sub_cand = substractHashMap(srcNgramCounts[i][n], candNgramCounts[n],
++ refsPerSen, refsPerSen);
++ HashMap<String, Integer> src_sub_ref = substractHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
++ refsPerSen, 1);
++ HashMap<String, Integer> src_sub_cand_sub_ref = intersectHashMap(src_sub_cand, src_sub_ref, 1, 1);
+
+ stats[StatIndex.values().length * (n - 1) + StatIndex.DELBOTH.ordinal()] = sumHashMapByValues(
+ src_sub_cand_sub_ref);
+ stats[StatIndex.values().length * (n - 1) + StatIndex.DELCAND.ordinal()] = sumHashMapByValues(
+ src_sub_cand);
+ stats[StatIndex.values().length * (n - 1) + StatIndex.DELREF.ordinal()] = sumHashMapByValues(
+ src_sub_ref);
+
+ // System.out.println("src_sub_cand_sub_ref" + src_sub_cand_sub_ref +
+ // sumHashMapByValues(src_sub_cand_sub_ref));
+ // System.out.println("src_sub_cand" + src_sub_cand + sumHashMapByValues(src_sub_cand));
+ // System.out.println("src_sub_ref" + src_sub_ref + sumHashMapByValues(src_sub_ref));
+
+ stats[StatIndex.values().length * (n - 1) + StatIndex.DELREF.ordinal()] = src_sub_ref.keySet()
- .size() * this.refsPerSen;
++ .size() * refsPerSen;
+
+ // KEEP OPERATIONS
- HashMap src_and_cand = intersectHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_and_ref = intersectHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_and_cand_and_ref = intersectHashMap(src_and_cand, src_and_ref, 1, 1);
++ HashMap<String, Integer> src_and_cand = intersectHashMap(srcNgramCounts[i][n], candNgramCounts[n],
++ refsPerSen, refsPerSen);
++ HashMap<String, Integer> src_and_ref = intersectHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
++ refsPerSen, 1);
++ HashMap<String, Integer> src_and_cand_and_ref = intersectHashMap(src_and_cand, src_and_ref, 1, 1);
+
+ stats[StatIndex.values().length * (n - 1)
+ + StatIndex.KEEPBOTH.ordinal()] = sumHashMapByValues(src_and_cand_and_ref);
+ stats[StatIndex.values().length * (n - 1)
+ + StatIndex.KEEPCAND.ordinal()] = sumHashMapByValues(src_and_cand);
+ stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPREF.ordinal()] = sumHashMapByValues(
+ src_and_ref);
+
+ stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPBOTH.ordinal()] = (int) (1000000
+ * sumHashMapByDoubleValues(divideHashMap(src_and_cand_and_ref, src_and_cand)));
+ stats[StatIndex.values().length * (n - 1)
+ + StatIndex.KEEPCAND.ordinal()] = (int) sumHashMapByDoubleValues(
+ divideHashMap(src_and_cand_and_ref, src_and_ref));
+ stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPREF.ordinal()] = src_and_ref
+ .keySet().size();
-
- // System.out.println("src_and_cand_and_ref" + src_and_cand_and_ref);
- // System.out.println("src_and_cand" + src_and_cand);
- // System.out.println("src_and_ref" + src_and_ref);
-
- // stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPBOTH2.ordinal()] = (int)
- // sumHashMapByDoubleValues(divideHashMap(src_and_cand_and_ref,src_and_ref)) * 100000000 /
- // src_and_ref.keySet().size() ;
- // stats[StatIndex.values().length * (n - 1) + StatIndex.KEEPREF.ordinal()] =
- // src_and_ref.keySet().size() * 8;
-
- // System.out.println("src_and_cand_and_ref" + src_and_cand_and_ref);
- // System.out.println("src_and_cand" + src_and_cand);
- // System.out.println("divide" + divideHashMap(src_and_cand_and_ref,src_and_cand));
- // System.out.println(sumHashMapByDoubleValues(divideHashMap(src_and_cand_and_ref,src_and_cand)));
-
+ }
-
- int n = 1;
-
- // System.out.println("CAND: " + candNgramCounts[n]);
- // System.out.println("SRC: " + srcNgramCounts[i][n]);
- // System.out.println("REF: " + refNgramCounts[i][n]);
-
- HashMap src_and_cand = intersectHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_and_ref = intersectHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_and_cand_and_ref = intersectHashMap(src_and_cand, src_and_ref, 1, 1);
- // System.out.println("SRC&CAND&REF : " + src_and_cand_and_ref);
-
- HashMap cand_sub_src = substractHashMap(candNgramCounts[n], srcNgramCounts[i][n]);
- HashMap cand_and_ref_sub_src = intersectHashMap(cand_sub_src, refNgramCounts[i][n]);
- // System.out.println("CAND&REF-SRC : " + cand_and_ref_sub_src);
-
- HashMap src_sub_cand = substractHashMap(srcNgramCounts[i][n], candNgramCounts[n],
- this.refsPerSen, this.refsPerSen);
- HashMap src_sub_ref = substractHashMap(srcNgramCounts[i][n], refNgramCounts[i][n],
- this.refsPerSen, 1);
- HashMap src_sub_cand_sub_ref = intersectHashMap(src_sub_cand, src_sub_ref, 1, 1);
- // System.out.println("SRC-REF-CAND : " + src_sub_cand_sub_ref);
-
- // System.out.println("DEBUG:" + Arrays.toString(stats));
- // System.out.println("REF-SRC: " + substractHashMap(refNgramCounts[i], srcNgramCounts[i][0],
- // (double)refsPerSen));
-
+ return stats;
+ }
+
+   /**
+    * Combines sufficient statistics into a single score. For every n-gram order n in
+    * 1..maxGramLength it adds weights[n] times each of: the F1 of the ADD statistics, the precision
+    * of the DEL statistics and the F1 of the KEEP statistics. The accumulated total is divided by 3.
+    *
+    * @param stats sufficient statistics, StatIndex.values().length consecutive entries per n-gram
+    *     order; its length must equal suffStatsCount, otherwise a message is printed and the
+    *     process exits with status 1
+    * @return the combined score
+    */
++  @Override
+   public double score(int[] stats) {
+     if (stats.length != suffStatsCount) {
+       System.out.println("Mismatch between stats.length and suffStatsCount (" + stats.length
+           + " vs. " + suffStatsCount + ") in NewMetric.score(int[])");
+       System.exit(1);
+     }
+
+     final int statsPerOrder = StatIndex.values().length;
+     double sc = 0.0;
+
+     for (int n = 1; n <= maxGramLength; ++n) {
+       // offset of this n-gram order's block of statistics
+       final int base = statsPerOrder * (n - 1);
+
+       // ADD: harmonic mean of precision and recall
+       int addCandCorrectNgram = stats[base + StatIndex.ADDBOTH.ordinal()];
+       int addCandTotalNgram = stats[base + StatIndex.ADDCAND.ordinal()];
+       int addRefTotalNgram = stats[base + StatIndex.ADDREF.ordinal()];
+
+       double prec_add_n = 0.0;
+       if (addCandTotalNgram > 0) {
+         prec_add_n = addCandCorrectNgram / (double) addCandTotalNgram;
+       }
+
+       double recall_add_n = 0.0;
+       if (addRefTotalNgram > 0) {
+         recall_add_n = addCandCorrectNgram / (double) addRefTotalNgram;
+       }
+
+       double f1_add_n = meanHarmonic(prec_add_n, recall_add_n);
+       sc += weights[n] * f1_add_n;
+
+       // DEL: only precision contributes to the score
+       int delCandCorrectNgram = stats[base + StatIndex.DELBOTH.ordinal()];
+       int delCandTotalNgram = stats[base + StatIndex.DELCAND.ordinal()];
- int delRefTotalNgram = stats[StatIndex.values().length * (n - 1)
- + StatIndex.DELREF.ordinal()];
+
+       double prec_del_n = 0.0;
+       if (delCandTotalNgram > 0) {
+         prec_del_n = delCandCorrectNgram / (double) delCandTotalNgram;
+       }
+
- double recall_del_n = 0.0;
- if (delRefTotalNgram > 0) {
- recall_del_n = delCandCorrectNgram / (double) delRefTotalNgram;
- }
-
-
- double f1_del_n = meanHarmonic(prec_del_n, recall_del_n);
-
+       sc += weights[n] * prec_del_n;
+
+       // KEEP: harmonic mean of precision and recall
+       int keepCandCorrectNgram = stats[base + StatIndex.KEEPBOTH.ordinal()];
+       int keepCandTotalNgram = stats[base + StatIndex.KEEPCAND.ordinal()];
+       int keepRefTotalNgram = stats[base + StatIndex.KEEPREF.ordinal()];
+
+       double prec_keep_n = 0.0;
+       if (keepCandTotalNgram > 0) {
+         // KEEPBOTH is stored multiplied by 1000000. Scale in double arithmetic: the int product
+         // 1000000 * keepCandTotalNgram overflows once keepCandTotalNgram exceeds 2147, which
+         // aggregated (e.g. corpus-level) statistics may reach.
+         prec_keep_n = keepCandCorrectNgram / (1000000.0 * keepCandTotalNgram);
+       }
+
+       double recall_keep_n = 0.0;
+       if (keepRefTotalNgram > 0) {
+         recall_keep_n = keepCandTotalNgram / (double) keepRefTotalNgram;
+       }
+
+       double f1_keep_n = meanHarmonic(prec_keep_n, recall_keep_n);
+       sc += weights[n] * f1_keep_n;
+     }
+
+     // one share per component (ADD, DEL, KEEP)
+     return sc / 3.0;
+   }
+
+ public double meanHarmonic(double precision, double recall) {
+
+ if (precision > 0 && recall > 0) {
+ return (2.0 * precision * recall) / (precision + recall);
+ }
+ return 0.0;
+ }
+
+ public void loadSources(String filepath) throws IOException {
+ srcSentences = new String[numSentences];
+ // BufferedReader br = new BufferedReader(new FileReader(filepath));
+ InputStream inStream = new FileInputStream(new File(filepath));
+ BufferedReader br = new BufferedReader(new InputStreamReader(inStream, "utf8"));
+
+ String line;
+ int i = 0;
+ while (i < numSentences && (line = br.readLine()) != null) {
+ srcSentences[i] = line.trim();
+ i++;
+ }
+ br.close();
+ }
+
+   /**
+    * Returns the sum of all values stored in {@code counter}; 0 for an empty map.
+    */
+   public double sumHashMapByDoubleValues(HashMap<String, Double> counter) {
+     double total = 0;
+     for (Double value : counter.values()) {
- sumcounts += (double) e.getValue();
++      total += value;
+     }
+     return total;
+   }
+
+   /**
+    * Returns the sum of all counts stored in {@code counter}; 0 for an empty map.
+    */
+   public int sumHashMapByValues(HashMap<String, Integer> counter) {
+     int total = 0;
+     for (Integer count : counter.values()) {
- sumcounts += (int) e.getValue();
++      total += count;
+     }
+     return total;
+   }
+
+   /**
+    * Returns the n-grams of {@code counter1} that are absent from {@code counter2} (or mapped to 0
+    * there), each with a count of 1.
+    */
+   public HashMap<String, Integer> substractHashMap(HashMap<String, Integer> counter1,
+       HashMap<String, Integer> counter2) {
+     HashMap<String, Integer> onlyInFirst = new HashMap<String, Integer>();
+     for (String ngram : counter1.keySet()) {
- int count1 = e.getValue();
+       int otherCount = counter2.getOrDefault(ngram, 0);
+       if (otherCount == 0) {
+         onlyInFirst.put(ngram, 1);
+       }
+     }
+     return onlyInFirst;
+   }
+
+   /**
+    * Computes counter1*ratio1 - counter2*ratio2 for each n-gram of {@code counter1} (a missing
+    * n-gram in {@code counter2} counts as 0) and keeps only the positive results. N-grams that
+    * occur only in {@code counter2} are ignored.
+    */
+   public HashMap<String, Integer> substractHashMap(HashMap<String, Integer> counter1,
+       HashMap<String, Integer> counter2, int ratio1, int ratio2) {
+     HashMap<String, Integer> difference = new HashMap<String, Integer>();
+     for (Map.Entry<String, Integer> entry : counter1.entrySet()) {
+       String ngram = entry.getKey();
+       int scaled1 = entry.getValue() * ratio1;
+       int scaled2 = counter2.getOrDefault(ngram, 0) * ratio2;
+       int remaining = scaled1 - scaled2;
+       if (remaining > 0) {
+         difference.put(ngram, remaining);
+       }
+     }
+     return difference;
+   }
+
+   /**
+    * For every n-gram of {@code counter1} with a non-zero count in {@code counter2}, maps it to
+    * count1 / count2 as a double; all other n-grams are omitted.
+    */
+   public HashMap<String, Double> divideHashMap(HashMap<String, Integer> counter1,
+       HashMap<String, Integer> counter2) {
+     HashMap<String, Double> ratios = new HashMap<String, Double>();
+     for (Map.Entry<String, Integer> entry : counter1.entrySet()) {
+       int numerator = entry.getValue();
+       int denominator = counter2.getOrDefault(entry.getKey(), 0);
+       if (denominator == 0) {
+         continue;
+       }
+       ratios.put(entry.getKey(), numerator / (double) denominator);
+     }
+     return ratios;
+   }
+
+   /**
+    * Returns the n-grams of {@code counter1} that have a positive count in {@code counter2}, each
+    * with a count of 1.
+    */
+   public HashMap<String, Integer> intersectHashMap(HashMap<String, Integer> counter1,
+       HashMap<String, Integer> counter2) {
+     HashMap<String, Integer> shared = new HashMap<String, Integer>();
+     for (String ngram : counter1.keySet()) {
- int count1 = e.getValue();
+       if (counter2.getOrDefault(ngram, 0) > 0) {
+         shared.put(ngram, 1);
+       }
+     }
+     return shared;
+   }
+
+   /**
+    * Computes min(counter1*ratio1, counter2*ratio2) for each n-gram of {@code counter1} (a missing
+    * n-gram in {@code counter2} counts as 0) and keeps only the positive results.
+    */
+   public HashMap<String, Integer> intersectHashMap(HashMap<String, Integer> counter1,
+       HashMap<String, Integer> counter2, int ratio1, int ratio2) {
+     HashMap<String, Integer> overlap = new HashMap<String, Integer>();
+     for (Map.Entry<String, Integer> entry : counter1.entrySet()) {
+       String ngram = entry.getKey();
+       int scaled1 = entry.getValue() * ratio1;
+       int scaled2 = counter2.getOrDefault(ngram, 0) * ratio2;
+       int common = Math.min(scaled1, scaled2);
+       if (common > 0) {
+         overlap.put(ngram, common);
+       }
+     }
+     return overlap;
+   }
+
+   /**
+    * Counts the whitespace-separated tokens of {@code cand_str}; the empty string has 0 words.
+    */
+   protected int wordCount(String cand_str) {
+     if (cand_str.equals("")) {
+       return 0;
+     }
+     return cand_str.split("\\s+").length;
+   }
+
+ public HashMap<String, Integer>[] getNgramCountsArray(String cand_str) {
+ if (!cand_str.equals("")) {
+ return getNgramCountsArray(cand_str.split("\\s+"));
+ } else {
+ return getNgramCountsArray(new String[0]);
+ }
+ }
+
+   /**
+    * Counts every n-gram of {@code words} for each order 1..maxGramLength.
+    *
+    * @return array indexed by n-gram order: slot 0 is null and slot n maps each n-gram (tokens
+    *     joined by single spaces) to its number of occurrences
+    */
+   public HashMap<String, Integer>[] getNgramCountsArray(String[] words) {
+     @SuppressWarnings("unchecked")
+     HashMap<String, Integer>[] counts = new HashMap[1 + maxGramLength];
+     counts[0] = null;
+     for (int order = 1; order <= maxGramLength; ++order) {
+       counts[order] = new HashMap<String, Integer>();
+     }
+
+     final int len = words.length;
+     int start = 0;
+
+     // start positions that leave room for an n-gram of every order
+     while (start <= len - maxGramLength) {
+       String gram = words[start];
+       counts[1].merge(gram, 1, Integer::sum);
+       for (int order = 2; order <= maxGramLength; ++order) {
+         gram = gram + " " + words[start + order - 1];
+         counts[order].merge(gram, 1, Integer::sum);
+       }
+       ++start;
+     }
+
+     // remaining start positions (all of them for sentences shorter than maxGramLength): only the
+     // n-grams that still end inside the sentence
+     while (start < len) {
+       String gram = words[start];
+       counts[1].merge(gram, 1, Integer::sum);
+       for (int end = start + 1, order = 2; end < len; ++end, ++order) {
+         gram = gram + " " + words[end];
+         counts[order].merge(gram, 1, Integer::sum);
+       }
+       ++start;
+     }
+
+     return counts;
+   }
+
+ public HashMap<String, Integer> getNgramCountsAll(String cand_str) {
+ if (!cand_str.equals("")) {
+ return getNgramCountsAll(cand_str.split("\\s+"));
+ } else {
+ return getNgramCountsAll(new String[0]);
+ }
+ }
+
+   /**
+    * Counts every n-gram of {@code words} of order 1..maxGramLength in one map; n-gram keys are
+    * the tokens joined by single spaces.
+    */
+   public HashMap<String, Integer> getNgramCountsAll(String[] words) {
+     HashMap<String, Integer> counts = new HashMap<String, Integer>();
+
+     final int len = words.length;
+     int start = 0;
+
+     // start positions that leave room for an n-gram of every order
+     while (start <= len - maxGramLength) {
+       String gram = words[start];
+       counts.merge(gram, 1, Integer::sum);
+       for (int order = 2; order <= maxGramLength; ++order) {
+         gram = gram + " " + words[start + order - 1];
+         counts.merge(gram, 1, Integer::sum);
+       }
+       ++start;
+     }
+
+     // remaining start positions (all of them for sentences shorter than maxGramLength): only the
+     // n-grams that still end inside the sentence
+     while (start < len) {
+       String gram = words[start];
+       counts.merge(gram, 1, Integer::sum);
+       for (int end = start + 1; end < len; ++end) {
+         gram = gram + " " + words[end];
+         counts.merge(gram, 1, Integer::sum);
+       }
+       ++start;
+     }
+
+     return counts;
+   }
+
+   /**
+    * Prints a single line "metricName = score" for the given statistics; {@code oneLiner} is not
+    * used.
+    */
++  @Override
+   public void printDetailedScore_fromStats(int[] stats, boolean oneLiner) {
+     final double value = score(stats);
+     System.out.println(metricName + " = " + value);
+   }
+
+ // Layout of the sufficient statistics for one n-gram order. score() reads stats in blocks of
+ // StatIndex.values().length entries per order, so the declaration order (ordinal) is significant:
+ // reordering the constants changes the meaning of stored statistics.
+ // DELREF and KEEPBOTH2 are not read by score().
+ private enum StatIndex {
+ KEEPBOTH, KEEPCAND, KEEPREF, DELBOTH, DELCAND, DELREF, ADDBOTH, ADDCAND, ADDREF, KEEPBOTH2
+ };
+
+}
[11/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/mira/MIRACore.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/mira/MIRACore.java
index 78b815a,0000000..e0354b9
mode 100755,000000..100755
--- a/joshua-core/src/main/java/org/apache/joshua/mira/MIRACore.java
+++ b/joshua-core/src/main/java/org/apache/joshua/mira/MIRACore.java
@@@ -1,3112 -1,0 +1,2921 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.mira;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Scanner;
+import java.util.TreeSet;
+import java.util.Vector;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
++import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.metrics.EvaluationMetric;
+import org.apache.joshua.util.StreamGobbler;
- import org.apache.joshua.corpus.Vocabulary;
++import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This code was originally written by Yuan Cao, who copied the MERT code to produce this file.
+ */
+
+public class MIRACore {
+
+ private static final Logger LOG = LoggerFactory.getLogger(MIRACore.class);
+
+ private final JoshuaConfiguration joshuaConfiguration;
+ private TreeSet<Integer>[] indicesOfInterest_all;
+
+ // used to print weight values with four decimal places
+ private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
- private final Runtime myRuntime = Runtime.getRuntime();
+
+ // IEEE infinities; initialize() uses one of them as the starting value of prevMetricScore
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+ private final static double epsilon = 1.0 / 1000000;
+
- private int progress;
-
+ private int verbosity; // anything of priority <= verbosity will be printed
+ // (lower value for priority means more important)
+
+ private Random randGen;
- private int generatedRands;
+
+ private int numSentences;
+ // number of sentences in the dev set
+ // (aka the "MERT training" set)
+
+ private int numDocuments;
+ // number of documents in the dev set
+ // this should be 1, unless doing doc-level optimization
+
+ private int[] docOfSentence;
+ // docOfSentence[i] stores which document contains the i'th sentence.
+ // docOfSentence is 0-indexed, as are the documents (i.e. first doc is indexed 0)
+
+ private int[] docSubsetInfo;
+ // stores information regarding which subset of the documents are evaluated
+ // [0]: method (0-6)
+ // [1]: first (1-indexed)
+ // [2]: last (1-indexed)
+ // [3]: size
+ // [4]: center
+ // [5]: arg1
+ // [6]: arg2
+ // [1-6] are 0 for method 0, [6] is 0 for methods 1-4 as well
+ // only [1] and [2] are needed for optimization. The rest are only needed for an output message.
+
+ private int refsPerSen;
+ // number of reference translations per sentence
+
+ private int textNormMethod;
+ // 0: no normalization, 1: "NIST-style" tokenization, and also rejoin 'm, 're, *'s, 've, 'll, 'd,
+ // and n't,
+ // 2: apply 1 and also rejoin dashes between letters, 3: apply 1 and also drop non-ASCII
+ // characters
+ // 4: apply 1+2+3
+
+ private int numParams;
+ // total number of firing features
+ // this number may increase over time as new n-best lists are decoded
+ // initially it is equal to the # of params in the parameter config file
+ private int numParamsOld;
+ // number of features before observing the new features fired in the current iteration
+
+ private double[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ /* *********************************************************** */
+ /* NOTE: indexing starts at 1 in the following few arrays: */
+ /* *********************************************************** */
+
+ // private double[] lambda;
+ private ArrayList<Double> lambda = new ArrayList<Double>();
+ // the current weight vector. NOTE: indexing starts at 1.
+ private ArrayList<Double> bestLambda = new ArrayList<Double>();
+ // the best weight vector across all iterations
+
+ private boolean[] isOptimizable;
+ // isOptimizable[c] = true iff lambda[c] should be optimized
+
+ private double[] minRandValue;
+ private double[] maxRandValue;
+ // when choosing a random value for the lambda[c] parameter, it will be
+ // chosen from the [minRandValue[c],maxRandValue[c]] range.
+ // (*) minRandValue and maxRandValue must be real values, but not -Inf or +Inf
+
+ private double[] defaultLambda;
+ // "default" parameter values; simply the values read in the parameter file
+ // USED FOR NON-OPTIMIZABLE (FIXED) FEATURES
+
+ /* *********************************************************** */
+ /* *********************************************************** */
+
+ private Decoder myDecoder;
+ // COMMENT OUT if decoder is not Joshua
+
+ private String decoderCommand;
+ // the command that runs the decoder; read from decoderCommandFileName
+
+ private int decVerbosity;
+ // verbosity level for decoder output. If 0, decoder output is ignored.
+ // If 1, decoder output is printed.
+
+ private int validDecoderExitValue;
+ // return value from running the decoder command that indicates success
+
+ private int numOptThreads;
+ // number of threads to run things in parallel
+
+ private int saveInterFiles;
+ // 0: nothing, 1: only configs, 2: only n-bests, 3: both configs and n-bests
+
+ private int compressFiles;
+ // should MIRA gzip the large files? If 0, no compression takes place.
+ // If 1, compression is performed on: decoder output files, temp sents files,
+ // and temp feats files.
+
+ private int sizeOfNBest;
+ // size of N-best list generated by decoder at each iteration
+ // (aka simply N, but N is a bad variable name)
+
+ private long seed;
+ // seed used to create random number generators
+
+ private boolean randInit;
+ // if true, parameters are initialized randomly. If false, parameters
+ // are initialized using values from parameter file.
+
+ private int maxMERTIterations, minMERTIterations, prevMERTIterations;
+ // max: maximum number of MERT iterations
+ // min: minimum number of MERT iterations before an early MERT exit
+ // prev: number of previous MERT iterations from which to consider candidates (in addition to
+ // the candidates from the current iteration)
+
+ private double stopSigValue;
+ // early MERT exit if no weight changes by more than stopSigValue
+ // (but see minMERTIterations above and stopMinIts below)
+
+ private int stopMinIts;
+ // some early stopping criterion must be satisfied in stopMinIts *consecutive* iterations
+ // before an early exit (but see minMERTIterations above)
+
+ private boolean oneModificationPerIteration;
+ // if true, each MERT iteration performs at most one parameter modification.
+ // If false, a new MERT iteration starts (i.e. a new N-best list is
+ // generated) only after the previous iteration reaches a local maximum.
+
+ private String metricName;
+ // name of evaluation metric optimized by MERT
+
+ private String metricName_display;
+ // name of evaluation metric optimized by MERT, possibly with "doc-level " prefixed
+
+ private String[] metricOptions;
+ // options for the evaluation metric (e.g. for BLEU, maxGramLength and effLengthMethod)
+
+ private EvaluationMetric evalMetric;
+ // the evaluation metric used by MERT
+
+ private int suffStatsCount;
+ // number of sufficient statistics for the evaluation metric
+
+ private String tmpDirPrefix;
+ // prefix for the MIRA.temp.* files
+
+ private boolean passIterationToDecoder;
+ // should the iteration number be passed as an argument to decoderCommandFileName?
+
+ // used by mira
+ private boolean needShuffle = true; // shuffle the training sentences or not
+ private boolean needAvg = true; // average the weights or not?
+ private boolean runPercep = false; // run perceptron instead of mira
+ private boolean usePseudoBleu = true; // need to use pseudo corpus to compute bleu?
+ private boolean returnBest = false; // return the best weight during tuning
+ private boolean needScale = true; // need scaling?
- private String trainingMode;
++
+ private int oraSelectMode = 1;
+ private int predSelectMode = 1;
+ private int miraIter = 1;
+ private int batchSize = 1;
+ private double C = 0.01; // relaxation coefficient
+ private double R = 0.99; // corpus decay when pseudo corpus is used for bleu computation
+ // private double sentForScale = 0.15; //percentage of sentences for scale factor estimation
+ private double scoreRatio = 5.0; // scale so that model_score/metric_score = scoreratio
+ private double prevMetricScore = 0; // final metric score of the previous iteration, used only
+ // when returnBest = true; initialize() resets it to PosInf or NegInf (the worst value for the
+ // metric's optimization direction)
+
+ private String dirPrefix; // where are all these files located?
+ private String paramsFileName, docInfoFileName, finalLambdaFileName;
+ private String sourceFileName, refFileName, decoderOutFileName;
+ private String decoderConfigFileName, decoderCommandFileName;
+ private String fakeFileNameTemplate, fakeFileNamePrefix, fakeFileNameSuffix;
+
+ // e.g. output.it[1-x].someOldRun would be specified as:
+ // output.it?.someOldRun
+ // and we'd have prefix = "output.it" and suffix = ".someOldRun"
+
+ // private int useDisk;
+
+ /**
+ * Creates an instance that only stores the decoder configuration; no arguments are parsed and
+ * initialize() is not called.
+ */
+ public MIRACore(JoshuaConfiguration joshuaConfiguration) {
+ this.joshuaConfiguration = joshuaConfiguration;
+ }
+
+ /**
+ * Creates and initializes an instance from command-line style arguments: calls
+ * EvaluationMetric.set_knownMetrics(), parses {@code args} with processArgsArray and then runs
+ * initialize(0).
+ *
+ * @throws IOException (including FileNotFoundException) propagated from initialize
+ */
- public MIRACore(String[] args, JoshuaConfiguration joshuaConfiguration) {
++ public MIRACore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(args);
+ initialize(0);
+ }
+
+ /**
+ * Creates and initializes an instance from a config file: calls
+ * EvaluationMetric.set_knownMetrics(), converts {@code configFileName} to an argument array with
+ * cfgFileToArgsArray, parses it and then runs initialize(0).
+ *
+ * @throws IOException (including FileNotFoundException) propagated from initialize
+ */
- public MIRACore(String configFileName, JoshuaConfiguration joshuaConfiguration) {
++ public MIRACore(String configFileName, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(cfgFileToArgsArray(configFileName));
+ initialize(0);
+ }
+
+   /**
+    * Prepares a tuning run. It seeds the random generator (discarding {@code randsToSkip} draws),
+    * counts the dev sentences and initial features, and registers the feature names in the
+    * vocabulary. It then reads and normalizes the references, sets the static EvaluationMetric
+    * state, and loads the Joshua decoder unless a decoder command or fake decoder output is
+    * configured.
+    *
+    * <p>When {@code randsToSkip == 0} it also prints a configuration summary and renames the
+    * decoder config file to {@code <name>.MIRA.orig}.
+    *
+    * @param randsToSkip number of random doubles to draw and discard after seeding
+    * @throws FileNotFoundException presumably when the reference or parameter file is missing
+    *     (propagated from ExistingUTF8EncodedTextFile; confirm)
+    * @throws IOException propagated from counting the lines of those files; I/O errors while
+    *     reading them afterwards are rethrown wrapped in a RuntimeException
+    */
- private void initialize(int randsToSkip) {
++  private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
+     println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
+
+     randGen = new Random(seed);
+     for (int r = 1; r <= randsToSkip; ++r) {
+       randGen.nextDouble();
+     }
- generatedRands = randsToSkip;
+
+     if (randsToSkip == 0) {
+       println("----------------------------------------------------", 1);
+       println("Initializing...", 1);
+       println("----------------------------------------------------", 1);
+       println("", 1);
+
+       println("Random number generator initialized using seed: " + seed, 1);
+       println("", 1);
+     }
+
+     // count the total num of sentences to be decoded; refFileName is the combined reference file
+     // holding refsPerSen consecutive lines per sentence
- numSentences = countLines(refFileName) / refsPerSen;
++    numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines() / refsPerSen;
+
+     // sets numDocuments and docOfSentence[]
+     processDocInfo();
+
+     if (numDocuments > 1) {
+       metricName_display = "doc-level " + metricName;
+     }
+
+     set_docSubsetInfo(docSubsetInfo);
+
+     // count the number of initial features; the parameter file has one extra line for the
+     // normalization method
- numParams = countNonEmptyLines(paramsFileName) - 1;
++    numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
+     numParamsOld = numParams;
+
+     // read dense parameter names and register them in the vocabulary. The file is decoded as
+     // UTF-8, matching the line counting above; try-with-resources closes it on every path.
+     try (InputStream inStream_names = new FileInputStream(paramsFileName);
+         BufferedReader inFile_names =
+             new BufferedReader(new InputStreamReader(inStream_names, "utf8"))) {
+       for (int c = 1; c <= numParams; ++c) {
+         String line = "";
+         while (line != null && line.length() == 0) { // skip empty lines
+           line = inFile_names.readLine();
+         }
+         if (line == null || line.indexOf("|||") < 0) {
+           throw new IllegalArgumentException("Missing or malformed feature line " + c
+               + " in parameter file " + paramsFileName + ": " + line);
+         }
+
+         // save feature names
+         String paramName = line.substring(0, line.indexOf("|||")).trim();
+         Vocabulary.id(paramName);
+       }
+     } catch (IOException e) {
+       throw new RuntimeException(e);
+     }
+
+     // the parameter file contains one line per parameter and one line for the normalization
+     // method; indexing starts at 1 in these structures (index 0 is a placeholder).
+     // lambda is a list because numParams may grow as new features fire; the arrays below are
+     // only used for initialization. Both weight lists get numParams + 1 zero entries.
+     for (int p = 0; p <= numParams; ++p) {
+       lambda.add(0.0);
+       bestLambda.add(0.0);
+     }
+     isOptimizable = new boolean[1 + numParams];
+     minRandValue = new double[1 + numParams];
+     maxRandValue = new double[1 + numParams];
+     defaultLambda = new double[1 + numParams];
+     normalizationOptions = new double[3];
+
+     // read initial param values; sets the arrays declared just above
+     processParamFile();
+
+     String[][] refSentences = new String[numSentences][refsPerSen];
+
+     try {
+       // read in reference sentences: refsPerSen consecutive lines per dev sentence
+       try (InputStream inStream_refs = new FileInputStream(new File(refFileName));
+           BufferedReader inFile_refs =
+               new BufferedReader(new InputStreamReader(inStream_refs, "utf8"))) {
+         for (int i = 0; i < numSentences; ++i) {
+           for (int r = 0; r < refsPerSen; ++r) {
+             // read the rth reference translation for the ith sentence
+             refSentences[i][r] = inFile_refs.readLine();
+           }
+         }
+       }
+
+       // normalize reference sentences
+       for (int i = 0; i < numSentences; ++i) {
+         for (int r = 0; r < refsPerSen; ++r) {
+           // normalize the rth reference translation for the ith sentence
+           refSentences[i][r] = normalize(refSentences[i][r], textNormMethod);
+         }
+       }
+
+       // read in decoder command, if any
+       decoderCommand = null;
+       if (decoderCommandFileName != null && fileExists(decoderCommandFileName)) {
+         try (BufferedReader inFile_comm =
+             new BufferedReader(new FileReader(decoderCommandFileName))) {
+           decoderCommand = inFile_comm.readLine(); // READ IN DECODE COMMAND
+         }
+       }
+     } catch (IOException e) {
+       throw new RuntimeException(e);
+     }
+
+     // set static data members for the EvaluationMetric class
+     EvaluationMetric.set_numSentences(numSentences);
+     EvaluationMetric.set_numDocuments(numDocuments);
+     EvaluationMetric.set_refsPerSen(refsPerSen);
+     EvaluationMetric.set_refSentences(refSentences);
+     EvaluationMetric.set_tmpDirPrefix(tmpDirPrefix);
+
+     evalMetric = EvaluationMetric.getMetric(metricName, metricOptions);
+     // used only if returnBest = true: start from the worst value for the metric's direction
+     prevMetricScore = evalMetric.getToBeMinimized() ? PosInf : NegInf;
+
+     // length of sufficient statistics
+     // for bleu: suffstatscount=8 (2*ngram+2)
+     suffStatsCount = evalMetric.get_suffStatsCount();
+
+     // print info
+     if (randsToSkip == 0) { // i.e. first iteration
+       println("Number of sentences: " + numSentences, 1);
+       println("Number of documents: " + numDocuments, 1);
+       println("Optimizing " + metricName_display, 1);
+
+       println("Number of initial features: " + numParams, 1);
+       print("Initial feature names: {", 1);
+
+       for (int c = 1; c <= numParams; ++c) {
+         print("\"" + Vocabulary.word(c) + "\"", 1);
+       }
+       println("}", 1);
+       println("", 1);
+
+       // TODO just print the correct info
+       println("c Default value\tOptimizable?\tRand. val. range", 1);
+
+       for (int c = 1; c <= numParams; ++c) {
+         print(c + " " + f4.format(lambda.get(c).doubleValue()) + "\t\t", 1);
+
+         if (!isOptimizable[c]) {
+           println(" No", 1);
+         } else {
+           print(" Yes\t\t", 1);
+           print(" [" + minRandValue[c] + "," + maxRandValue[c] + "]", 1);
+           println("", 1);
+         }
+       }
+
+       println("", 1);
+       print("Weight vector normalization method: ", 1);
+       if (normalizationOptions[0] == 0) {
+         println("none.", 1);
+       } else if (normalizationOptions[0] == 1) {
+         println(
+             "weights will be scaled so that the \""
+                 + Vocabulary.word((int) normalizationOptions[2])
+                 + "\" weight has an absolute value of " + normalizationOptions[1] + ".", 1);
+       } else if (normalizationOptions[0] == 2) {
+         println("weights will be scaled so that the maximum absolute value is "
+             + normalizationOptions[1] + ".", 1);
+       } else if (normalizationOptions[0] == 3) {
+         println("weights will be scaled so that the minimum absolute value is "
+             + normalizationOptions[1] + ".", 1);
+       } else if (normalizationOptions[0] == 4) {
+         println("weights will be scaled so that the L-" + normalizationOptions[1] + " norm is "
+             + normalizationOptions[2] + ".", 1);
+       }
+
+       println("", 1);
+
+       println("----------------------------------------------------", 1);
+       println("", 1);
+
+       // rename original config file so it doesn't get overwritten
+       // (original name will be restored in finish())
+       renameFile(decoderConfigFileName, decoderConfigFileName + ".MIRA.orig");
+     } // if (randsToSkip == 0)
+
+     // by default, load joshua decoder
+     if (decoderCommand == null && fakeFileNameTemplate == null) {
+       println("Loading Joshua decoder...", 1);
+       myDecoder = new Decoder(joshuaConfiguration);
+       println("...finished loading @ " + (new Date()), 1);
+       println("");
+     } else {
+       myDecoder = null;
+     }
+
+     @SuppressWarnings("unchecked")
+     TreeSet<Integer>[] temp_TSA = new TreeSet[numSentences];
+     indicesOfInterest_all = temp_TSA;
+
+     for (int i = 0; i < numSentences; ++i) {
+       indicesOfInterest_all[i] = new TreeSet<Integer>();
+     }
+   } // void initialize(...)
+
+ // -------------------------
+
+  /**
+   * Runs MIRA using this instance's configured iteration limits; equivalent to
+   * {@code run_MIRA(minMERTIterations, maxMERTIterations, prevMERTIterations)}.
+   */
+  public void run_MIRA() {
+    final int minIts = minMERTIterations;
+    final int maxIts = maxMERTIterations;
+    final int prevIts = prevMERTIterations;
+    this.run_MIRA(minIts, maxIts, prevIts);
+  }
+
+ public void run_MIRA(int minIts, int maxIts, int prevIts) {
+ // FIRST, CLEAN ALL PREVIOUS TEMP FILES
+ String dir;
+ int k = tmpDirPrefix.lastIndexOf("/");
+ if (k >= 0) {
+ dir = tmpDirPrefix.substring(0, k + 1);
+ } else {
+ dir = "./";
+ }
+ String files;
+ File folder = new File(dir);
+
+ if (folder.exists()) {
+ File[] listOfFiles = folder.listFiles();
+
+ for (int i = 0; i < listOfFiles.length; i++) {
+ if (listOfFiles[i].isFile()) {
+ files = listOfFiles[i].getName();
+ if (files.startsWith("MIRA.temp")) {
+ deleteFile(files);
+ }
+ }
+ }
+ }
+
+ println("----------------------------------------------------", 1);
+ println("MIRA run started @ " + (new Date()), 1);
+ // printMemoryUsage();
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ // if no default lambda is provided
+ if (randInit) {
+ println("Initializing lambda[] randomly.", 1);
+ // initialize optimizable parameters randomly (sampling uniformly from
+ // that parameter's random value range)
+ lambda = randomLambda();
+ }
+
+ println("Initial lambda[]: " + lambdaToString(lambda), 1);
+ println("", 1);
+
+ int[] maxIndex = new int[numSentences];
+
+ // HashMap<Integer,int[]>[] suffStats_array = new HashMap[numSentences];
+ // suffStats_array[i] maps candidates of interest for sentence i to an array
+ // storing the sufficient statistics for that candidate
+
+ int earlyStop = 0;
+ // number of consecutive iteration an early stopping criterion was satisfied
+
+ for (int iteration = 1;; ++iteration) {
+
+ // what does "A" contain?
+ // retA[0]: FINAL_score
+ // retA[1]: earlyStop
+ // retA[2]: should this be the last iteration?
+ double[] A = run_single_iteration(iteration, minIts, maxIts, prevIts, earlyStop, maxIndex);
+ if (A != null) {
+ earlyStop = (int) A[1];
+ if (A[2] == 1)
+ break;
+ } else {
+ break;
+ }
+
+ } // for (iteration)
+
+ println("", 1);
+
+ println("----------------------------------------------------", 1);
+ println("MIRA run ended @ " + (new Date()), 1);
+ // printMemoryUsage();
+ println("----------------------------------------------------", 1);
+ println("", 1);
+ if (!returnBest)
+ println("FINAL lambda: " + lambdaToString(lambda), 1);
+ // + " (" + metricName_display + ": " + FINAL_score + ")",1);
+ else
+ println("BEST lambda: " + lambdaToString(lambda), 1);
+
+ // delete intermediate .temp.*.it* decoder output files
+ for (int iteration = 1; iteration <= maxIts; ++iteration) {
+ if (compressFiles == 1) {
+ deleteFile(tmpDirPrefix + "temp.sents.it" + iteration + ".gz");
+ deleteFile(tmpDirPrefix + "temp.feats.it" + iteration + ".gz");
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".copy.gz")) {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".copy.gz");
+ } else {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz");
+ }
+ } else {
+ deleteFile(tmpDirPrefix + "temp.sents.it" + iteration);
+ deleteFile(tmpDirPrefix + "temp.feats.it" + iteration);
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".copy")) {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".copy");
+ } else {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+ }
+ }
+ } // void run_MIRA(int maxIts)
+
+  // One MIRA iteration (decode, merge n-best candidates, optimize); returns retA or null.
+  @SuppressWarnings("unchecked")
+  public double[] run_single_iteration(int iteration, int minIts, int maxIts, int prevIts,
+      int earlyStop, int[] maxIndex) {
+    double FINAL_score = 0;
+    // NOTE(review): FINAL_score is never updated, so retA[0] is always 0; maxIndex is unused.
+    double[] retA = new double[3];
+    // retA[0]: FINAL_score
+    // retA[1]: earlyStop
+    // retA[2]: should this be the last iteration?
+    // Returns null (not retA) if no new candidates were added and !oneModificationPerIteration.
+    boolean done = false;
+    retA[2] = 1; // will only be made 0 if we don't break from the following loop
+    // For sentence i: candidate string -> feature line (feat_hash) / stats line (stats_hash).
+    // save feats and stats for all candidates(old & new)
+    HashMap<String, String>[] feat_hash = new HashMap[numSentences];
+    for (int i = 0; i < numSentences; i++)
+      feat_hash[i] = new HashMap<String, String>();
+    // earlyStop counts consecutive early-stop iterations; its updated value goes to retA[1].
+    HashMap<String, String>[] stats_hash = new HashMap[numSentences];
+    for (int i = 0; i < numSentences; i++)
+      stats_hash[i] = new HashMap<String, String>();
+    // The while body runs once; it exits via 'break' (retA[2] stays 1) or done = true.
+    while (!done) { // NOTE: this "loop" will only be carried out once
+      println("--- Starting MIRA iteration #" + iteration + " @ " + (new Date()) + " ---", 1);
+
+      // printMemoryUsage();
+
+      /******************************/
+      // CREATE DECODER CONFIG FILE //
+      /******************************/
+
+      createConfigFile(lambda, decoderConfigFileName, decoderConfigFileName + ".MIRA.orig");
+      // i.e. use the original config file as a template
+
+      /***************/
+      // RUN DECODER //
+      /***************/
+
+      if (iteration == 1) {
+        println("Decoding using initial weight vector " + lambdaToString(lambda), 1);
+      } else {
+        println("Redecoding using weight vector " + lambdaToString(lambda), 1);
+      }
+
+      // generate the n-best file after decoding
+      String[] decRunResult = run_decoder(iteration); // iteration passed in case fake decoder will
+                                                      // be used
+      // [0] name of file to be processed
+      // [1] indicates how the output file was obtained:
+      // 1: external decoder
+      // 2: fake decoder
+      // 3: internal decoder
+
+      if (!decRunResult[1].equals("2")) {
+        println("...finished decoding @ " + (new Date()), 1);
+      }
+
+      checkFile(decRunResult[0]);
+
+      /************* END OF DECODING **************/
+
+      println("Producing temp files for iteration " + iteration, 3);
+      // presumably writes the temp.sents/temp.feats files opened below -- TODO confirm
+      produceTempFiles(decRunResult[0], iteration);
+
+      // save intermediate output files
+      // save joshua.config.mira.it*
+      if (saveInterFiles == 1 || saveInterFiles == 3) { // make copy of intermediate config file
+        if (!copyFile(decoderConfigFileName, decoderConfigFileName + ".MIRA.it" + iteration)) {
+          println("Warning: attempt to make copy of decoder config file (to create"
+              + decoderConfigFileName + ".MIRA.it" + iteration + ") was unsuccessful!", 1);
+        }
+      }
+
+      // save a copy of the decoder's n-best output as <output>.MIRA.it*
+      if (saveInterFiles == 2 || saveInterFiles == 3) { // make copy of intermediate decoder output
+                                                        // file...
+
+        if (!decRunResult[1].equals("2")) { // ...but only if no fake decoder
+          if (!decRunResult[0].endsWith(".gz")) {
+            if (!copyFile(decRunResult[0], decRunResult[0] + ".MIRA.it" + iteration)) {
+              println("Warning: attempt to make copy of decoder output file (to create"
+                  + decRunResult[0] + ".MIRA.it" + iteration + ") was unsuccessful!", 1);
+            }
+          } else {
+            String prefix = decRunResult[0].substring(0, decRunResult[0].length() - 3);
+            if (!copyFile(prefix + ".gz", prefix + ".MIRA.it" + iteration + ".gz")) {
+              println("Warning: attempt to make copy of decoder output file (to create" + prefix
+                  + ".MIRA.it" + iteration + ".gz" + ") was unsuccessful!", 1);
+            }
+          }
+
+          if (compressFiles == 1 && !decRunResult[0].endsWith(".gz")) {
+            gzipFile(decRunResult[0] + ".MIRA.it" + iteration);
+          }
+        } // if (!fake)
+      }
+
+      // ------------- end of saving .mira.it* files ---------------
+      // NOTE(review): lastUsedIndex is never read; suffStats_array is only cleared at the end.
+      int[] candCount = new int[numSentences];
+      int[] lastUsedIndex = new int[numSentences];
+
+      ConcurrentHashMap<Integer, int[]>[] suffStats_array = new ConcurrentHashMap[numSentences];
+      for (int i = 0; i < numSentences; ++i) {
+        candCount[i] = 0;
+        lastUsedIndex[i] = -1;
+        // suffStats_array[i].clear();
+        suffStats_array[i] = new ConcurrentHashMap<Integer, int[]>();
+      }
+
+      // initLambda[0] is not used!
+      double[] initialLambda = new double[1 + numParams];
+      for (int i = 1; i <= numParams; ++i)
+        initialLambda[i] = lambda.get(i);
+
+      // the "score" in initialScore refers to that
+      // assigned by the evaluation metric)
+
+      // you may consider all candidates from iter 1, or from iter (iteration-prevIts) to current
+      // iteration
+      int firstIt = Math.max(1, iteration - prevIts);
+      // i.e. only process candidates from the current iteration and candidates
+      // from up to prevIts previous iterations.
+      println("Reading candidate translations from iterations " + firstIt + "-" + iteration, 1);
+      println("(and computing " + metricName
+          + " sufficient statistics for previously unseen candidates)", 1);
+      print(" Progress: ");
+
+      int[] newCandidatesAdded = new int[1 + iteration];
+      for (int it = 1; it <= iteration; ++it)
+        newCandidatesAdded[it] = 0;
+
+      try {
+        // read temp files from all past iterations
+        // 3 types of temp files:
+        // 1. output hypo at iter i
+        // 2. feature value of each hypo at iter i
+        // 3. suff stats of each hypo at iter i
+
+        // each inFile corresponds to the output of an iteration
+        // (index 0 is not used; no corresponding index for the current iteration)
+        BufferedReader[] inFile_sents = new BufferedReader[iteration];
+        BufferedReader[] inFile_feats = new BufferedReader[iteration];
+        BufferedReader[] inFile_stats = new BufferedReader[iteration];
+        // NOTE(review): readers/writers opened below are not closed if an IOException is thrown.
+        // temp file(array) from previous iterations
+        for (int it = firstIt; it < iteration; ++it) {
+          InputStream inStream_sents, inStream_feats, inStream_stats;
+          if (compressFiles == 0) {
+            inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+            inStream_feats = new FileInputStream(tmpDirPrefix + "temp.feats.it" + it);
+            inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+          } else {
+            inStream_sents = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it"
+                + it + ".gz"));
+            inStream_feats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.feats.it"
+                + it + ".gz"));
+            inStream_stats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it"
+                + it + ".gz"));
+          }
+
+          inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+          inFile_feats[it] = new BufferedReader(new InputStreamReader(inStream_feats, "utf8"));
+          inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+        }
+
+        InputStream inStream_sentsCurrIt, inStream_featsCurrIt, inStream_statsCurrIt;
+        // temp file for current iteration!
+        if (compressFiles == 0) {
+          inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+          inStream_featsCurrIt = new FileInputStream(tmpDirPrefix + "temp.feats.it" + iteration);
+        } else {
+          inStream_sentsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.sents.it" + iteration + ".gz"));
+          inStream_featsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.feats.it" + iteration + ".gz"));
+        }
+
+        BufferedReader inFile_sentsCurrIt = new BufferedReader(new InputStreamReader(
+            inStream_sentsCurrIt, "utf8"));
+        BufferedReader inFile_featsCurrIt = new BufferedReader(new InputStreamReader(
+            inStream_featsCurrIt, "utf8"));
+
+        BufferedReader inFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below
+                                                  // is set to true
+        PrintWriter outFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below is
+                                                // set to false
+
+        // just to check if temp.stat.it.iteration exists
+        boolean statsCurrIt_exists = false;
+        // Reuse existing stats for this iteration (after backing up a .copy); else write them.
+        if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration)) {
+          inStream_statsCurrIt = new FileInputStream(tmpDirPrefix + "temp.stats.it" + iteration);
+          inFile_statsCurrIt = new BufferedReader(new InputStreamReader(inStream_statsCurrIt,
+              "utf8"));
+          statsCurrIt_exists = true;
+          copyFile(tmpDirPrefix + "temp.stats.it" + iteration, tmpDirPrefix + "temp.stats.it"
+              + iteration + ".copy");
+        } else if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".gz")) {
+          inStream_statsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.stats.it" + iteration + ".gz"));
+          inFile_statsCurrIt = new BufferedReader(new InputStreamReader(inStream_statsCurrIt,
+              "utf8"));
+          statsCurrIt_exists = true;
+          copyFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz", tmpDirPrefix
+              + "temp.stats.it" + iteration + ".copy.gz");
+        } else {
+          outFile_statsCurrIt = new PrintWriter(tmpDirPrefix + "temp.stats.it" + iteration);
+        }
+
+        // output the 4^th temp file: *.temp.stats.merged
+        PrintWriter outFile_statsMerged = new PrintWriter(tmpDirPrefix + "temp.stats.merged");
+        // write sufficient statistics from all the sentences
+        // from the output files into a single file
+        PrintWriter outFile_statsMergedKnown = new PrintWriter(tmpDirPrefix
+            + "temp.stats.mergedKnown");
+        // write sufficient statistics from all the sentences
+        // from the output files into a single file
+
+        // output the 5^th 6^th temp file, but will be deleted at the end of the function
+        FileOutputStream outStream_unknownCands = new FileOutputStream(tmpDirPrefix
+            + "temp.currIt.unknownCands", false);
+        OutputStreamWriter outStreamWriter_unknownCands = new OutputStreamWriter(
+            outStream_unknownCands, "utf8");
+        BufferedWriter outFile_unknownCands = new BufferedWriter(outStreamWriter_unknownCands);
+
+        PrintWriter outFile_unknownIndices = new PrintWriter(tmpDirPrefix
+            + "temp.currIt.unknownIndices");
+
+        String sents_str, feats_str, stats_str;
+
+        // BUG: this assumes a candidate string cannot be produced for two
+        // different source sentences, which is not necessarily true
+        // (It's not actually a bug, but only because existingCandStats gets
+        // cleared before moving to the next source sentence.)
+        // FIX: should be made an array, indexed by i
+        HashMap<String, String> existingCandStats = new HashMap<String, String>();
+        // VERY IMPORTANT:
+        // A CANDIDATE X MAY APPEAR IN ITER 1, ITER 3
+        // BUT IF THE USER SPECIFIED TO CONSIDER ITERATIONS FROM ONLY ITER 2, THEN
+        // X IS NOT A "REPEATED" CANDIDATE IN ITER 3. THEREFORE WE WANT TO KEEP THE
+        // SUFF STATS FOR EACH CANDIDATE(TO SAVE COMPUTATION IN THE FUTURE)
+
+        // Stores precalculated sufficient statistics for candidates, in case
+        // the same candidate is seen again. (SS stored as a String.)
+        // Q: Why do we care? If we see the same candidate again, aren't we going
+        // to ignore it? So, why do we care about the SS of this repeat candidate?
+        // A: A "repeat" candidate may not be a repeat candidate in later
+        // iterations if the user specifies a value for prevMERTIterations
+        // that causes MERT to skip candidates from early iterations.
+
-        double[] currFeatVal = new double[1 + numParams];
+        String[] featVal_str;
+
+        int totalCandidateCount = 0;
+
+        // new candidate size for each sentence
+        int[] sizeUnknown_currIt = new int[numSentences];
+
+        for (int i = 0; i < numSentences; ++i) {
+          // process candidates from previous iterations
+          // low efficiency? for each iteration, it reads in all previous iteration outputs
+          // therefore a lot of overlapping jobs
+          // this is an easy implementation to deal with the situation in which user only specified
+          // "previt" and hopes to consider only the previous previt
+          // iterations, then for each iteration the existing candidates will be different
+          for (int it = firstIt; it < iteration; ++it) {
+            // Why up to but *excluding* iteration?
+            // Because the last iteration is handled a little differently, since
+            // the SS must be calculated (and the corresponding file created),
+            // which is not true for previous iterations.
+
+            for (int n = 0; n <= sizeOfNBest; ++n) {
+              // note that in all temp files, "||||||" is a separator between 2 n-best lists
+
+              // Why up to and *including* sizeOfNBest?
+              // So that it would read the "||||||" separator even if there is
+              // a complete list of sizeOfNBest candidates.
+
+              // for the nth candidate for the ith sentence, read the sentence, feature values,
+              // and sufficient statistics from the various temp files
+
+              // read one line of temp.sent, temp.feat, temp.stats from iteration it
+              sents_str = inFile_sents[it].readLine();
+              feats_str = inFile_feats[it].readLine();
+              stats_str = inFile_stats[it].readLine();
+
+              if (sents_str.equals("||||||")) {
+                n = sizeOfNBest + 1; // move on to the next n-best list
+              } else if (!existingCandStats.containsKey(sents_str)) // if this candidate does not
+                                                                    // exist
+              {
+                outFile_statsMergedKnown.println(stats_str);
+
+                // save feats & stats
+                feat_hash[i].put(sents_str, feats_str);
+                stats_hash[i].put(sents_str, stats_str);
+
+                // extract feature value (NOTE(review): featVal_str is not read in this pass)
+                featVal_str = feats_str.split("\\s+");
+
+                existingCandStats.put(sents_str, stats_str);
+                candCount[i] += 1;
+                newCandidatesAdded[it] += 1;
+
+              } // if unseen candidate
+            } // for (n)
+          } // for (it)
+
+          outFile_statsMergedKnown.println("||||||");
+
+          // ---------- end of processing previous iterations ----------
+          // ---------- now start processing new candidates ----------
+
+          // now process the candidates of the current iteration
+          // now determine the new candidates of the current iteration
+
+          /*
+           * remember: BufferedReader inFile_sentsCurrIt BufferedReader inFile_featsCurrIt
+           * PrintWriter outFile_statsCurrIt
+           */
+
+          String[] sentsCurrIt_currSrcSent = new String[sizeOfNBest + 1];
+
+          Vector<String> unknownCands_V = new Vector<String>();
+          // which candidates (of the i'th source sentence) have not been seen before
+          // this iteration?
+
+          for (int n = 0; n <= sizeOfNBest; ++n) {
+            // Why up to and *including* sizeOfNBest?
+            // So that it would read the "||||||" separator even if there is
+            // a complete list of sizeOfNBest candidates.
+
+            // for the nth candidate for the ith sentence, read the sentence,
+            // and store it in the sentsCurrIt_currSrcSent array
+
+            sents_str = inFile_sentsCurrIt.readLine(); // read one candidate from the current
+                                                       // iteration
+            sentsCurrIt_currSrcSent[n] = sents_str; // Note: possibly "||||||"
+
+            if (sents_str.equals("||||||")) {
+              n = sizeOfNBest + 1;
+            } else if (!existingCandStats.containsKey(sents_str)) {
+              unknownCands_V.add(sents_str); // NEW CANDIDATE FROM THIS ITERATION
+              writeLine(sents_str, outFile_unknownCands);
+              outFile_unknownIndices.println(i); // INDEX OF THE NEW CANDIDATES
+              newCandidatesAdded[iteration] += 1;
+              existingCandStats.put(sents_str, "U"); // i.e. unknown
+              // we add sents_str to avoid duplicate entries in unknownCands_V
+            }
+          } // for (n)
+
+          // only compute suff stats for new candidates
+          // now unknownCands_V has the candidates for which we need to calculate
+          // sufficient statistics (for the i'th source sentence)
+          int sizeUnknown = unknownCands_V.size();
+          sizeUnknown_currIt[i] = sizeUnknown;
+
+          existingCandStats.clear();
+
+        } // for (i) each sentence
+
+        // ---------- end of merging candidates stats from previous iterations
+        // and finding new candidates ------------
+
+        /*
+         * int[][] newSuffStats = null; if (!statsCurrIt_exists && sizeUnknown > 0) { newSuffStats =
+         * evalMetric.suffStats(unknownCands, indices); }
+         */
+
+        outFile_statsMergedKnown.close();
+        outFile_unknownCands.close();
+        outFile_unknownIndices.close();
+        // Second pass: reopen sents/stats of previous iterations (feats are not needed again).
+        // The first pass above consumed these readers, so they are re-opened from the start.
+        for (int it = firstIt; it < iteration; ++it) // previous iterations temp files
+        {
+          inFile_sents[it].close();
+          inFile_stats[it].close();
+
+          InputStream inStream_sents, inStream_stats;
+          if (compressFiles == 0) {
+            inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+            inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+          } else {
+            inStream_sents = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it"
+                + it + ".gz"));
+            inStream_stats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it"
+                + it + ".gz"));
+          }
+
+          inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+          inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+        }
+
+        inFile_sentsCurrIt.close();
+        // current iteration temp files
+        if (compressFiles == 0) {
+          inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+        } else {
+          inStream_sentsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+              + "temp.sents.it" + iteration + ".gz"));
+        }
+        inFile_sentsCurrIt = new BufferedReader(new InputStreamReader(inStream_sentsCurrIt, "utf8"));
+
+        // calculate SS for unseen candidates and write them to file
+        FileInputStream inStream_statsCurrIt_unknown = null;
+        BufferedReader inFile_statsCurrIt_unknown = null;
+
+        if (!statsCurrIt_exists && newCandidatesAdded[iteration] > 0) {
+          // create the file...
+          evalMetric.createSuffStatsFile(tmpDirPrefix + "temp.currIt.unknownCands", tmpDirPrefix
+              + "temp.currIt.unknownIndices", tmpDirPrefix + "temp.stats.unknown", sizeOfNBest);
+
+          // ...and open it
+          inStream_statsCurrIt_unknown = new FileInputStream(tmpDirPrefix + "temp.stats.unknown");
+          inFile_statsCurrIt_unknown = new BufferedReader(new InputStreamReader(
+              inStream_statsCurrIt_unknown, "utf8"));
+        }
+
+        // open mergedKnown file
+        // newly created by the big loop above
+        FileInputStream instream_statsMergedKnown = new FileInputStream(tmpDirPrefix
+            + "temp.stats.mergedKnown");
+        BufferedReader inFile_statsMergedKnown = new BufferedReader(new InputStreamReader(
+            instream_statsMergedKnown, "utf8"));
+        // numParams may grow in the loop below as previously unseen feature names appear.
+        // num of features before observing new firing features from this iteration
+        numParamsOld = numParams;
+
+        for (int i = 0; i < numSentences; ++i) {
+          // reprocess candidates from previous iterations
+          for (int it = firstIt; it < iteration; ++it) {
+            for (int n = 0; n <= sizeOfNBest; ++n) {
+              sents_str = inFile_sents[it].readLine();
+              stats_str = inFile_stats[it].readLine();
+
+              if (sents_str.equals("||||||")) {
+                n = sizeOfNBest + 1;
+              } else if (!existingCandStats.containsKey(sents_str)) {
+                existingCandStats.put(sents_str, stats_str);
+              } // if unseen candidate
+            } // for (n)
+          } // for (it)
+
+          // copy relevant portion from mergedKnown to the merged file
+          String line_mergedKnown = inFile_statsMergedKnown.readLine();
+          while (!line_mergedKnown.equals("||||||")) {
+            outFile_statsMerged.println(line_mergedKnown);
+            line_mergedKnown = inFile_statsMergedKnown.readLine();
+          }
+          // NOTE(review): 'stats' is filled below but never read afterwards.
+          int[] stats = new int[suffStatsCount];
+
+          for (int n = 0; n <= sizeOfNBest; ++n) {
+            sents_str = inFile_sentsCurrIt.readLine();
+            feats_str = inFile_featsCurrIt.readLine();
+
+            if (sents_str.equals("||||||")) {
+              n = sizeOfNBest + 1;
+            } else if (!existingCandStats.containsKey(sents_str)) {
+
+              if (!statsCurrIt_exists) {
+                stats_str = inFile_statsCurrIt_unknown.readLine();
+
+                String[] temp_stats = stats_str.split("\\s+");
+                for (int s = 0; s < suffStatsCount; ++s) {
+                  stats[s] = Integer.parseInt(temp_stats[s]);
+                }
+
+                outFile_statsCurrIt.println(stats_str);
+              } else {
+                stats_str = inFile_statsCurrIt.readLine();
+
+                String[] temp_stats = stats_str.split("\\s+");
+                for (int s = 0; s < suffStatsCount; ++s) {
+                  stats[s] = Integer.parseInt(temp_stats[s]);
+                }
+              }
+
+              outFile_statsMerged.println(stats_str);
+
+              // save feats & stats
+              // System.out.println(sents_str+" "+feats_str);
+
+              feat_hash[i].put(sents_str, feats_str);
+              stats_hash[i].put(sents_str, stats_str);
+
+              featVal_str = feats_str.split("\\s+");
+
+              if (feats_str.indexOf('=') != -1) {
+                for (String featurePair : featVal_str) {
+                  String[] pair = featurePair.split("=");
+                  String name = pair[0];
-                  Double value = Double.parseDouble(pair[1]);
+                  int featId = Vocabulary.id(name);
+                  // assumes an unseen name gets id numParams + 1 from Vocabulary -- TODO confirm
+                  // need to identify newly fired feats here
+                  // in this case currFeatVal is not given the value
+                  // of the new feat, since the corresponding weight is
+                  // initialized as zero anyway
+                  if (featId > numParams) {
+                    ++numParams;
+                    lambda.add(new Double(0));
+                  }
+                }
+              }
+              existingCandStats.put(sents_str, stats_str);
+              candCount[i] += 1;
+
+              // newCandidatesAdded[iteration] += 1;
+              // moved to code above detecting new candidates
+            } else {
+              if (statsCurrIt_exists)
+                inFile_statsCurrIt.readLine();
+              else {
+                // write SS to outFile_statsCurrIt
+                stats_str = existingCandStats.get(sents_str);
+                outFile_statsCurrIt.println(stats_str);
+              }
+            }
+
+          } // for (n)
+
+          // now d = sizeUnknown_currIt[i] - 1
+
+          if (statsCurrIt_exists)
+            inFile_statsCurrIt.readLine();
+          else
+            outFile_statsCurrIt.println("||||||");
+
+          existingCandStats.clear();
+          totalCandidateCount += candCount[i];
+
+          // output sentence progress
+          if ((i + 1) % 500 == 0) {
+            print((i + 1) + "\n" + " ", 1);
+          } else if ((i + 1) % 100 == 0) {
+            print("+", 1);
+          } else if ((i + 1) % 25 == 0) {
+            print(".", 1);
+          }
+
+        } // for (i)
+
+        inFile_statsMergedKnown.close();
+        outFile_statsMerged.close();
+
+        // for testing
+        /*
+         * int total_sent = 0; for( int i=0; i<numSentences; i++ ) {
+         * System.out.println(feat_hash[i].size()+" "+candCount[i]); total_sent +=
+         * feat_hash[i].size(); feat_hash[i].clear(); }
+         * System.out.println("----------------total sent: "+total_sent); total_sent = 0; for( int
+         * i=0; i<numSentences; i++ ) { System.out.println(stats_hash[i].size()+" "+candCount[i]);
+         * total_sent += stats_hash[i].size(); stats_hash[i].clear(); }
+         * System.out.println("*****************total sent: "+total_sent);
+         */
+
+        println("", 1); // finish progress line
+
+        for (int it = firstIt; it < iteration; ++it) {
+          inFile_sents[it].close();
+          inFile_feats[it].close();
+          inFile_stats[it].close();
+        }
+
+        inFile_sentsCurrIt.close();
+        inFile_featsCurrIt.close();
+        if (statsCurrIt_exists)
+          inFile_statsCurrIt.close();
+        else
+          outFile_statsCurrIt.close();
+
+        if (compressFiles == 1 && !statsCurrIt_exists) {
+          gzipFile(tmpDirPrefix + "temp.stats.it" + iteration);
+        }
+
+        // clear temp files
+        deleteFile(tmpDirPrefix + "temp.currIt.unknownCands");
+        deleteFile(tmpDirPrefix + "temp.currIt.unknownIndices");
+        deleteFile(tmpDirPrefix + "temp.stats.unknown");
+        deleteFile(tmpDirPrefix + "temp.stats.mergedKnown");
+
+        // cleanupMemory();
+
+        println("Processed " + totalCandidateCount + " distinct candidates " + "(about "
+            + totalCandidateCount / numSentences + " per sentence):", 1);
+        for (int it = firstIt; it <= iteration; ++it) {
+          println("newCandidatesAdded[it=" + it + "] = " + newCandidatesAdded[it] + " (about "
+              + newCandidatesAdded[it] / numSentences + " per sentence)", 1);
+        }
+
+        println("", 1);
+
+        println("Number of features observed so far: " + numParams);
+        println("", 1);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+
+      // n-best list converges
+      if (newCandidatesAdded[iteration] == 0) {
+        if (!oneModificationPerIteration) {
+          println("No new candidates added in this iteration; exiting MIRA.", 1);
+          println("", 1);
+          println("--- MIRA iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+          println("", 1);
+          deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+          if (returnBest) {
+            // note that bestLambda.size() <= lambda.size()
+            for (int p = 1; p < bestLambda.size(); ++p)
+              lambda.set(p, bestLambda.get(p));
+            // and set the rest of lambda to be 0
+            for (int p = 0; p < lambda.size() - bestLambda.size(); ++p)
+              lambda.set(p + bestLambda.size(), new Double(0));
+          }
+
+          return null; // this means that the old values should be kept by the caller
+        } else {
+          println("Note: No new candidates added in this iteration.", 1);
+        }
+      }
+
+      /************* start optimization **************/
+
+      /*
+       * for( int v=1; v<initialLambda[1].length; v++ ) System.out.print(initialLambda[1][v]+" ");
+       * System.exit(0);
+       */
+      // Optimizer is configured through static fields before it is constructed.
+      Optimizer.sentNum = numSentences; // total number of training sentences
+      Optimizer.needShuffle = needShuffle;
+      Optimizer.miraIter = miraIter;
+      Optimizer.oraSelectMode = oraSelectMode;
+      Optimizer.predSelectMode = predSelectMode;
+      Optimizer.runPercep = runPercep;
+      Optimizer.C = C;
+      Optimizer.needAvg = needAvg;
+      // Optimizer.sentForScale = sentForScale;
+      Optimizer.scoreRatio = scoreRatio;
+      Optimizer.evalMetric = evalMetric;
+      Optimizer.normalizationOptions = normalizationOptions;
+      Optimizer.needScale = needScale;
+      Optimizer.batchSize = batchSize;
+
+      // if need to use bleu stats history
+      if (iteration == 1) {
+        if (evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+          Optimizer.initBleuHistory(numSentences, evalMetric.get_suffStatsCount());
+          Optimizer.usePseudoBleu = usePseudoBleu;
+          Optimizer.R = R;
+        }
+        if (evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+          Optimizer.initBleuHistory(numSentences, evalMetric.get_suffStatsCount() - 2); // Stats
+                                                                                        // count of
+                                                                                        // TER=2
+          Optimizer.usePseudoBleu = usePseudoBleu;
+          Optimizer.R = R;
+        }
+      }
+
+      Vector<String> output = new Vector<String>();
+
+      // note: initialLambda[] has length = numParamsOld
+      // augmented with new feature weights, initial values are 0
+      double[] initialLambdaNew = new double[1 + numParams];
+      System.arraycopy(initialLambda, 1, initialLambdaNew, 1, numParamsOld);
+
+      // finalLambda[] has length = numParams (considering new features)
+      double[] finalLambda = new double[1 + numParams];
+      // NOTE(review): the array above is discarded; runOptimizer() supplies the result.
+      Optimizer opt = new Optimizer(output, isOptimizable, initialLambdaNew, feat_hash, stats_hash);
+      finalLambda = opt.runOptimizer();
+
+      if (returnBest) {
+        double metricScore = opt.getMetricScore();
+        if (!evalMetric.getToBeMinimized()) {
+          if (metricScore > prevMetricScore) {
+            prevMetricScore = metricScore;
+            for (int p = 1; p < bestLambda.size(); ++p)
+              bestLambda.set(p, finalLambda[p]);
+            if (1 + numParams > bestLambda.size()) {
+              for (int p = bestLambda.size(); p <= numParams; ++p)
+                bestLambda.add(p, finalLambda[p]);
+            }
+          }
+        } else {
+          if (metricScore < prevMetricScore) {
+            prevMetricScore = metricScore;
+            for (int p = 1; p < bestLambda.size(); ++p)
+              bestLambda.set(p, finalLambda[p]);
+            if (1 + numParams > bestLambda.size()) {
+              for (int p = bestLambda.size(); p <= numParams; ++p)
+                bestLambda.add(p, finalLambda[p]);
+            }
+          }
+        }
+      }
+
+      // System.out.println(finalLambda.length);
+      // for( int i=0; i<finalLambda.length-1; i++ )
+      // System.out.println(finalLambda[i+1]);
+
+      /************* end optimization **************/
+
+      for (int i = 0; i < output.size(); i++)
+        println(output.get(i));
+
+      // check if any parameter has been updated
+      boolean anyParamChanged = false;
+      boolean anyParamChangedSignificantly = false;
+
+      for (int c = 1; c <= numParams; ++c) {
+        if (finalLambda[c] != lambda.get(c)) {
+          anyParamChanged = true;
+        }
+        if (Math.abs(finalLambda[c] - lambda.get(c)) > stopSigValue) {
+          anyParamChangedSignificantly = true;
+        }
+      }
+
+      // System.arraycopy(finalLambda,1,lambda,1,numParams);
+
+      println("--- MIRA iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+      println("", 1);
+      // NOTE(review): unlike the two exits below, this exit does not copy bestLambda to lambda.
+      if (!anyParamChanged) {
+        println("No parameter value changed in this iteration; exiting MIRA.", 1);
+        println("", 1);
+        break; // leaves while (!done) with retA[2] == 1, so run_MIRA's loop stops
+      }
+
+      // was an early stopping criterion satisfied?
+      boolean critSatisfied = false;
+      if (!anyParamChangedSignificantly && stopSigValue >= 0) {
+        println("Note: No parameter value changed significantly " + "(i.e. by more than "
+            + stopSigValue + ") in this iteration.", 1);
+        critSatisfied = true;
+      }
+
+      if (critSatisfied) {
+        ++earlyStop;
+        println("", 1);
+      } else {
+        earlyStop = 0;
+      }
+
+      // if min number of iterations executed, investigate if early exit should happen
+      if (iteration >= minIts && earlyStop >= stopMinIts) {
+        println("Some early stopping criteria has been observed " + "in " + stopMinIts
+            + " consecutive iterations; exiting MIRA.", 1);
+        println("", 1);
+
+        if (returnBest) {
+          for (int f = 1; f <= bestLambda.size() - 1; ++f)
+            lambda.set(f, bestLambda.get(f));
+        } else {
+          for (int f = 1; f <= numParams; ++f)
+            lambda.set(f, finalLambda[f]);
+        }
+
+        break; // leaves while (!done) with retA[2] == 1, so run_MIRA's loop stops
+      }
+
+      // if max number of iterations executed, exit
+      if (iteration >= maxIts) {
+        println("Maximum number of MIRA iterations reached; exiting MIRA.", 1);
+        println("", 1);
+
+        if (returnBest) {
+          for (int f = 1; f <= bestLambda.size() - 1; ++f)
+            lambda.set(f, bestLambda.get(f));
+        } else {
+          for (int f = 1; f <= numParams; ++f)
+            lambda.set(f, finalLambda[f]);
+        }
+
+        break; // leaves while (!done) with retA[2] == 1, so run_MIRA's loop stops
+      }
+
+      // use the new wt vector to decode the next iteration
+      // (interpolation with previous wt vector)
+      double interCoef = 1.0; // no interpolation for now
+      for (int i = 1; i <= numParams; i++)
+        lambda.set(i, interCoef * finalLambda[i] + (1 - interCoef) * lambda.get(i).doubleValue());
+
+      println("Next iteration will decode with lambda: " + lambdaToString(lambda), 1);
+      println("", 1);
+
+      // printMemoryUsage();
+      for (int i = 0; i < numSentences; ++i) {
+        suffStats_array[i].clear();
+      }
+      // cleanupMemory();
+      // println("",2);
+
+      retA[2] = 0; // i.e. this should NOT be the last iteration
+      done = true;
+
+    } // while (!done) // NOTE: this "loop" will only be carried out once
+
+    // delete .temp.stats.merged file, since it is not needed in the next
+    // iteration (it will be recreated from scratch)
+    deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+    retA[0] = FINAL_score;
+    retA[1] = earlyStop;
+    return retA;
+
+  } // run_single_iteration
+
+  /**
+   * Formats at most the first 15 weights of {@code lambdaA} for logging.
+   *
+   * <p>Entries are read from index 1 up to {@code min(numParams, 15)}, each with four decimal
+   * places; index 0 is not printed, matching the 1-based loops over lambda used in this class.
+   *
+   * @param lambdaA weight list, indexed from 1 to {@code numParams}
+   * @return a string of the form {@code "{(listing the first k lambdas)w1, ..., wk}"}
+   */
+  private String lambdaToString(ArrayList<Double> lambdaA) {
+    // print at most the first 15 features
+    int featToPrint = numParams > 15 ? 15 : numParams;
+
+    StringBuilder retStr = new StringBuilder("{");
+    retStr.append("(listing the first ").append(featToPrint).append(" lambdas)");
+    for (int c = 1; c <= featToPrint - 1; ++c) {
+      retStr.append(String.format("%.4f", lambdaA.get(c).doubleValue())).append(", ");
+    }
+    // The last printed entry must be index featToPrint, not numParams: with more than 15
+    // features the old code skipped entry 15 and printed the final feature in its place.
+    retStr.append(String.format("%.4f", lambdaA.get(featToPrint).doubleValue())).append("}");
+
+    return retStr.toString();
+  }
+
+  /**
+   * Obtains the n-best output for {@code iteration}, either from an existing "fake" file or by
+   * running the external decoder command and waiting for it to finish.
+   *
+   * @param iteration current iteration; also passed to the decoder command when
+   *     {@code passIterationToDecoder} is set
+   * @return a two-element array: [0] the name of the file to be processed, [1] how it was
+   *     obtained ("1" = external decoder, "2" = fake decoder; "3" = internal decoder is listed
+   *     as a possibility but is not produced by this method)
+   * @throws RuntimeException if the decoder cannot be started, exits with a status other than
+   *     {@code validDecoderExitValue}, or the wait is interrupted
+   */
+  private String[] run_decoder(int iteration) {
+    // retSA saves the output file name (nbest-file) and the decoder type
+    // [0] name of file to be processed
+    // [1] indicates how the output file was obtained:
+    // 1: external decoder
+    // 2: fake decoder
+    // 3: internal decoder
+    String[] retSA = new String[2];
+
+    // use fake decoder
+    if (fakeFileNameTemplate != null
+        && fileExists(fakeFileNamePrefix + iteration + fakeFileNameSuffix)) {
+      String fakeFileName = fakeFileNamePrefix + iteration + fakeFileNameSuffix;
+      println("Not running decoder; using " + fakeFileName + " instead.", 1);
+      /*
+       * if (fakeFileName.endsWith(".gz")) { copyFile(fakeFileName,decoderOutFileName+".gz");
+       * gunzipFile(decoderOutFileName+".gz"); } else { copyFile(fakeFileName,decoderOutFileName); }
+       */
+      retSA[0] = fakeFileName;
+      retSA[1] = "2";
+
+    } else {
+      println("Running external decoder...", 1);
+
+      try {
+        ArrayList<String> cmd = new ArrayList<String>();
+        cmd.add(decoderCommandFileName);
+
+        if (passIterationToDecoder) {
+          cmd.add(Integer.toString(iteration));
+        }
+
+        ProcessBuilder pb = new ProcessBuilder(cmd);
+        // this merges the error and output streams of the subprocess
+        pb.redirectErrorStream(true);
+        Process p = pb.start();
+
+        // capture the sub-command's output
+        new StreamGobbler(p.getInputStream(), decVerbosity).start();
+
+        int decStatus = p.waitFor();
+        if (decStatus != validDecoderExitValue) {
+          throw new RuntimeException("Call to decoder returned " + decStatus + "; was expecting "
+              + validDecoderExitValue + ".");
+        }
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      } catch (InterruptedException e) {
+        // restore the interrupt status so code further up the stack can still observe it
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      }
+
+      retSA[0] = decoderOutFileName;
+      retSA[1] = "1";
+
+    }
+
+    return retSA;
+  }
+
+  /**
+   * Splits an n-best file into two per-iteration temp files under {@code tmpDirPrefix}: one with
+   * the normalized candidate translations ({@code temp.sents.it<iteration>}) and one with their
+   * feature strings ({@code temp.feats.it<iteration>}).
+   *
+   * <p>Each n-best line is expected to look like {@code i ||| candidate ||| features [||| ...]}.
+   * A "||||||" separator line is written to both files whenever the sentence index changes,
+   * after every {@code sizeOfNBest} candidates, and once more at the end if the number of
+   * separators written differs from {@code numSentences}. Both files are gzipped afterwards
+   * when {@code compressFiles == 1}.
+   *
+   * @param nbestFileName n-best file to read (gzip-compressed if its name ends in ".gz")
+   * @param iteration iteration number used in the temp file names
+   * @throws RuntimeException wrapping any I/O error
+   */
+  private void produceTempFiles(String nbestFileName, int iteration) {
+    String sentsFileName = tmpDirPrefix + "temp.sents.it" + iteration;
+    String featsFileName = tmpDirPrefix + "temp.feats.it" + iteration;
+
+    try {
+      // try-with-resources closes every stream even if parsing fails part-way through. The raw
+      // file streams are declared separately so they are still closed if a wrapping
+      // constructor (e.g. GZIPInputStream on a corrupt header) throws.
+      try (FileOutputStream outStream_sents = new FileOutputStream(sentsFileName, false);
+          BufferedWriter outFile_sents =
+              new BufferedWriter(new OutputStreamWriter(outStream_sents, "utf8"));
+          PrintWriter outFile_feats = new PrintWriter(featsFileName);
+          InputStream rawStream_nbest = new FileInputStream(nbestFileName);
+          InputStream inStream_nbest = nbestFileName.endsWith(".gz")
+              ? new GZIPInputStream(rawStream_nbest) : rawStream_nbest;
+          BufferedReader inFile_nbest =
+              new BufferedReader(new InputStreamReader(inStream_nbest, "utf8"))) {
+
+        String line; // , prevLine;
+        String candidate_str = "";
+        String feats_str = "";
+
+        int i = 0;
+        int n = 0;
+        line = inFile_nbest.readLine();
+
+        while (line != null) {
+
+          /*
+           * line format:
- *
++ *
+           * i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val
+           * .*
+           */
+
+          // in a well formed file, we'd find the nth candidate for the ith sentence
+
+          int read_i = Integer.parseInt((line.substring(0, line.indexOf("|||"))).trim());
+
+          if (read_i != i) {
+            writeLine("||||||", outFile_sents);
+            outFile_feats.println("||||||");
+            n = 0;
+            ++i;
+          }
+
+          line = (line.substring(line.indexOf("|||") + 3)).trim(); // get rid of initial text
+
+          candidate_str = (line.substring(0, line.indexOf("|||"))).trim();
+          feats_str = (line.substring(line.indexOf("|||") + 3)).trim();
+          // get rid of candidate string
+
+          int junk_i = feats_str.indexOf("|||");
+          if (junk_i >= 0) {
+            feats_str = (feats_str.substring(0, junk_i)).trim();
+          }
+
+          writeLine(normalize(candidate_str, textNormMethod), outFile_sents);
+          outFile_feats.println(feats_str);
+
+          ++n;
+          if (n == sizeOfNBest) {
+            writeLine("||||||", outFile_sents);
+            outFile_feats.println("||||||");
+            n = 0;
+            ++i;
+          }
+
+          line = inFile_nbest.readLine();
+        }
+
+        if (i != numSentences) { // last sentence had too few candidates
+          writeLine("||||||", outFile_sents);
+          outFile_feats.println("||||||");
+        }
+      }
+
+      // the output files are closed at this point, so they can be compressed safely
+      if (compressFiles == 1) {
+        gzipFile(sentsFileName);
+        gzipFile(featsFileName);
+      }
+
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+  }
+
+  /**
+   * Writes {@code cfgFileName} as a copy of {@code templateFileName} in which every line that
+   * starts with a feature name followed by a space is replaced by that feature's weight from
+   * {@code params}.
+   *
+   * <p>Weights with absolute value at most 1e-20 are omitted. Features with an index greater
+   * than the number of feature lines found in the template are appended at the end.
+   *
+   * @param params weights indexed from 1 to {@code numParams}
+   * @param cfgFileName path of the config file to create (overwritten)
+   * @param templateFileName path of the template config file to read
+   * @throws RuntimeException wrapping any I/O error, including write errors
+   */
+  private void createConfigFile(ArrayList<Double> params, String cfgFileName,
+      String templateFileName) {
+    // i.e. create cfgFileName, which is similar to templateFileName, but with
+    // params[] as parameter values
+    try (BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
+        PrintWriter outFile = new PrintWriter(cfgFileName)) {
+
- BufferedReader inFeatDefFile = null;
- PrintWriter outFeatDefFile = null;
+      int origFeatNum = 0; // feat num in the template file
+
+      String line = inFile.readLine();
+      while (line != null) {
+        int c_match = -1;
+        for (int c = 1; c <= numParams; ++c) {
+          if (line.startsWith(Vocabulary.word(c) + " ")) {
+            c_match = c;
+            ++origFeatNum;
+            break;
+          }
+        }
+
+        if (c_match == -1) {
+          outFile.println(line);
+        } else if (Math.abs(params.get(c_match).doubleValue()) > 1e-20) {
+          outFile.println(Vocabulary.word(c_match) + " " + params.get(c_match));
+        }
+
+        line = inFile.readLine();
+      }
+
+      // now append weights of new features
+      // NOTE(review): assumes the template's features are exactly indices 1..origFeatNum
+      for (int c = origFeatNum + 1; c <= numParams; ++c) {
+        if (Math.abs(params.get(c).doubleValue()) > 1e-20) {
+          outFile.println(Vocabulary.word(c) + " " + params.get(c));
+        }
+      }
+
+      // PrintWriter swallows write errors; surface them rather than leave a truncated config
+      if (outFile.checkError()) {
+        throw new IOException("Error writing " + cfgFileName);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Reads the parameter file {@code paramsFileName} and initializes {@code lambda},
+   * {@code defaultLambda}, {@code isOptimizable}, {@code minRandValue}, {@code maxRandValue} and
+   * {@code normalizationOptions} from it.
+   *
+   * <p>For each parameter 1..{@code numParams} in order, the file must contain: tokens up to
+   * and including a {@code |||} separator (the parameter name), the default value, {@code Opt}
+   * or {@code Fix}, and four further tokens. For {@code Opt} parameters the first two of those
+   * are ignored (kept for ZMERT params-file compatibility), and the last two are the finite
+   * min/max of the range for random values. After the parameters, the first non-empty line is
+   * parsed as {@code normalization = <method> [args]}.
+   *
+   * @throws RuntimeException if the file is missing or contains an invalid value
+   */
+  private void processParamFile() {
+    // process parameter file
+    Scanner inFile_init = null;
+    try {
+      inFile_init = new Scanner(new FileReader(paramsFileName));
+    } catch (FileNotFoundException e) {
+      throw new RuntimeException(e);
+    }
+
+    // any parsing step below may throw; the finally block closes the scanner regardless
+    try {
+      String dummy = "";
+
+      // initialize lambda[] and other related arrays
+      for (int c = 1; c <= numParams; ++c) {
+        // skip parameter name
+        while (!dummy.equals("|||")) {
+          dummy = inFile_init.next();
+        }
+
+        // read default value
+        lambda.set(c, inFile_init.nextDouble());
+        defaultLambda[c] = lambda.get(c).doubleValue();
+
+        // read isOptimizable
+        dummy = inFile_init.next();
+        if (dummy.equals("Opt")) {
+          isOptimizable[c] = true;
+        } else if (dummy.equals("Fix")) {
+          isOptimizable[c] = false;
+        } else {
+          throw new RuntimeException("Unknown isOptimizable string " + dummy + " (must be either Opt or Fix)");
+        }
+
+        if (!isOptimizable[c]) { // skip the next four values (not used for fixed parameters)
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+        } else {
+          // the next two values are not used, only to be consistent with ZMERT's params file format
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+          // set minRandValue[c] and maxRandValue[c] (range for random values)
+          dummy = inFile_init.next();
+          if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+            throw new RuntimeException("minRandValue[" + c + "] cannot be -Inf or +Inf!");
+          } else {
+            minRandValue[c] = Double.parseDouble(dummy);
+          }
+
+          dummy = inFile_init.next();
+          if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+            throw new RuntimeException("maxRandValue[" + c + "] cannot be -Inf or +Inf!");
+          } else {
+            maxRandValue[c] = Double.parseDouble(dummy);
+          }
+
+          // check for illogical values
+          if (minRandValue[c] > maxRandValue[c]) {
+            throw new RuntimeException("minRandValue[" + c + "]=" + minRandValue[c]
+                + " > " + maxRandValue[c] + "=maxRandValue[" + c + "]!");
+          }
+
+          // check for odd values
+          if (minRandValue[c] == maxRandValue[c]) {
+            println("Warning: lambda[" + c + "] has " + "minRandValue = maxRandValue = "
+                + minRandValue[c] + ".", 1);
+          }
+        } // if (!isOptimizable[c])
+      }
+
+      // set normalizationOptions[]
+      // (the first nextLine() returns the rest of the last parameter's line, usually empty)
+      String origLine = "";
+      while (origLine != null && origLine.length() == 0) {
+        origLine = inFile_init.nextLine();
+      }
+
+      // How should a lambda[] vector be normalized (before decoding)?
+      // nO[0] = 0: no normalization
+      // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+      // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+      // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+      // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+      // normalization = none
+      // normalization = absval 1 lm
+      // normalization = maxabsval 1
+      // normalization = minabsval 1
+      // normalization = LNorm 2 1
+
+      dummy = (origLine.substring(origLine.indexOf("=") + 1)).trim();
+      String[] dummyA = dummy.split("\\s+");
+
+      if (dummyA[0].equals("none")) {
+        normalizationOptions[0] = 0;
+      } else if (dummyA[0].equals("absval")) {
+        normalizationOptions[0] = 1;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        String pName = dummyA[2];
+        for (int i = 3; i < dummyA.length; ++i) { // in case parameter name has multiple words
+          pName = pName + " " + dummyA[i];
+        }
+        normalizationOptions[2] = Vocabulary.id(pName);
+
+        if (normalizationOptions[1] <= 0) {
+          throw new RuntimeException("Value for the absval normalization method must be positive.");
+        }
+        if (normalizationOptions[2] == 0) {
+          // report the name that failed to resolve, not its (zero) id
+          throw new RuntimeException("Unrecognized feature name " + pName
+              + " for absval normalization method.");
+        }
+      } else if (dummyA[0].equals("maxabsval")) {
+        normalizationOptions[0] = 2;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        if (normalizationOptions[1] <= 0) {
+          throw new RuntimeException("Value for the maxabsval normalization method must be positive.");
+        }
+      } else if (dummyA[0].equals("minabsval")) {
+        normalizationOptions[0] = 3;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        if (normalizationOptions[1] <= 0) {
+          throw new RuntimeException("Value for the minabsval normalization method must be positive.");
+        }
+      } else if (dummyA[0].equals("LNorm")) {
+        normalizationOptions[0] = 4;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        normalizationOptions[2] = Double.parseDouble(dummyA[2]);
+        if (normalizationOptions[1] <= 0 || normalizationOptions[2] <= 0) {
+          throw new RuntimeException("Both values for the LNorm normalization method must be"
+              + " positive.");
+        }
+      } else {
+        throw new RuntimeException("Unrecognized normalization method " + dummyA[0] + "; "
+            + "must be one of none, absval, maxabsval, minabsval, and LNorm.");
+      } // if (dummyA[0])
+    } finally {
+      inFile_init.close();
+    }
+  } // processParamFile()
+
+  /**
+   * Sets {@code numDocuments} and {@code docOfSentence[]} (the 0-based document index of each
+   * sentence).
+   *
+   * <p>Without a doc-info file, every sentence belongs to a single document 0. Otherwise the
+   * format is chosen by comparing the file's number of non-empty lines with
+   * {@code numSentences}: fewer lines means one line per document (formats 1 and 2), an equal
+   * count means one line per sentence (formats 3 and 4).
+   *
+   * @throws RuntimeException if the file cannot be read, or has more non-empty lines than there
+   *     are sentences
+   */
+  private void processDocInfo() {
+    // sets numDocuments and docOfSentence[]
+    docOfSentence = new int[numSentences];
+
+    if (docInfoFileName == null) {
+      // a new int[] is zero-filled, so every sentence already belongs to document 0
+      numDocuments = 1;
+    } else {
+
+      try {
+
+        // 4 possible formats:
+        // 1) List of numbers, one per document, indicating # sentences in each document.
+        // 2) List of "docName size" pairs, one per document, indicating name of document and #
+        // sentences.
+        // 3) List of docName's, one per sentence, indicating which document each sentence
+        // belongs to.
+        // 4) List of docName_number's, one per sentence, indicating which document each
+        // sentence belongs to, and its order in that document. (can also use '-' instead of '_')
+
- int docInfoSize = countNonEmptyLines(docInfoFileName);
++ int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
+
+        // The readers below decode UTF-8, presumably matching how ExistingUTF8EncodedTextFile
+        // counted the lines above; FileReader would use the platform charset instead.
+        if (docInfoSize < numSentences) { // format #1 or #2
+          numDocuments = docInfoSize;
+          int i = 0;
+
+          try (BufferedReader inFile = new BufferedReader(
+              new InputStreamReader(new FileInputStream(docInfoFileName), "utf8"))) {
+            String line = inFile.readLine();
+            boolean format1 = (!(line.contains(" ")));
+
+            for (int doc = 0; doc < numDocuments; ++doc) {
+
+              if (doc != 0) {
+                line = inFile.readLine();
+              }
+
+              int docSize = 0;
+              if (format1) {
+                docSize = Integer.parseInt(line);
+              } else {
+                docSize = Integer.parseInt(line.split("\\s+")[1]);
+              }
+
+              for (int i2 = 1; i2 <= docSize; ++i2) {
+                docOfSentence[i] = doc;
+                ++i;
+              }
+            }
+          }
+          // now i == numSentences
+
+        } else if (docInfoSize == numSentences) { // format #3 or #4
+
+          // format 3 (plain doc names) is detected by a repeated line; in format 4 every line
+          // carries a per-document sentence number and so is unique
+          boolean format3 = false;
+
+          HashSet<String> seenStrings = new HashSet<String>();
+          try (BufferedReader inFile = new BufferedReader(
+              new InputStreamReader(new FileInputStream(docInfoFileName), "utf8"))) {
+            for (int i = 0; i < numSentences; ++i) {
+              // set format3 = true if a duplicate is found
+              String line = inFile.readLine();
+              if (!seenStrings.add(line)) {
+                format3 = true;
+              }
+            }
+          }
+
+          // maps a document name to the order (0-indexed) in which it was seen
+          HashMap<String, Integer> docOrder = new HashMap<String, Integer>();
+
+          try (BufferedReader inFile = new BufferedReader(
+              new InputStreamReader(new FileInputStream(docInfoFileName), "utf8"))) {
+            for (int i = 0; i < numSentences; ++i) {
+              String line = inFile.readLine();
+
+              String docName = "";
+              if (format3) {
+                docName = line;
+              } else {
+                int sep_i = Math.max(line.lastIndexOf('_'), line.lastIndexOf('-'));
+                docName = line.substring(0, sep_i);
+              }
+
+              Integer docOrder_i = docOrder.get(docName);
+              if (docOrder_i == null) {
+                docOrder_i = docOrder.size();
+                docOrder.put(docName, docOrder_i);
+              }
+
+              docOfSentence[i] = docOrder_i;
+            }
+          }
+
+          numDocuments = docOrder.size();
+
+        } else { // badly formatted
+          // previously ignored silently, leaving numDocuments unset; fail loudly instead
+          throw new RuntimeException("Doc info file " + docInfoFileName + " has " + docInfoSize
+              + " non-empty lines, more than the number of sentences (" + numSentences + ").");
+        }
+
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+  }
+
+  /**
+   * Copies the contents of {@code origFileName} byte-for-byte into {@code newFileName},
+   * overwriting any existing file.
+   *
+   * @param origFileName source file
+   * @param newFileName destination file
+   * @return true on success; false if an I/O error occurred (the error is logged)
+   */
+  private boolean copyFile(String origFileName, String newFileName) {
+    // both streams are closed even if the copy fails part-way through
+    try (InputStream in = new FileInputStream(new File(origFileName));
+        OutputStream out = new FileOutputStream(new File(newFileName))) {
+
+      byte[] buffer = new byte[1024];
+      int len;
+      while ((len = in.read(buffer)) != -1) {
+        out.write(buffer, 0, len);
+      }
+
+      /*
+       * InputStream inStream = new FileInputStream(new File(origFileName)); BufferedReader inFile =
+       * new BufferedReader(new InputStreamReader(inStream, "utf8"));
- *
++ *
+       * FileOutputStream outStream = new FileOutputStream(newFileName, false); OutputStreamWriter
+       * outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new
+       * BufferedWriter(outStreamWriter);
- *
++ *
+       * String line; while(inFile.ready()) { line = inFile.readLine(); writeLine(line, outFile); }
- *
++ *
+       * inFile.close(); outFile.close();
+       */
+      return true;
+    } catch (IOException e) {
+      LOG.error(e.getMessage(), e);
+      return false;
+    }
+  }
+
+  /**
+   * Renames {@code origFileName} to {@code newFileName}, first deleting any existing file at the
+   * destination. Failures are reported as warnings, not exceptions.
+   */
+  private void renameFile(String origFileName, String newFileName) {
+    if (!fileExists(origFileName)) {
+      println("Warning: file " + origFileName + " does not exist! (in MIRACore.renameFile)", 1);
+      return;
+    }
+
+    deleteFile(newFileName);
+    boolean renamed = new File(origFileName).renameTo(new File(newFileName));
+    if (!renamed) {
+      println("Warning: attempt to rename " + origFileName + " to " + newFileName
+          + " was unsuccessful!", 1);
+    }
+  }
+
+  /** Deletes {@code fileName} if it exists, printing a warning when the deletion fails. */
+  private void deleteFile(String fileName) {
+    if (!fileExists(fileName)) {
+      return;
+    }
+    if (!new File(fileName).delete()) {
+      println("Warning: attempt to delete " + fileName + " was unsuccessful!", 1);
+    }
+  }
+
+  /**
+   * Writes {@code line} and a line separator to {@code writer}, then flushes it.
+   *
+   * @throws IOException if writing or flushing fails
+   */
+  private void writeLine(String line, BufferedWriter writer) throws IOException {
+    writer.write(line);
+    writer.newLine();
+    writer.flush();
+  }
+
+ // need to re-write to handle different forms of lambda
+ public void finish() {
+ if (myDecoder != null) {
+ myDecoder.cleanUp();
+ }
+
+ // create config file with final values
+ createConfigFile(lambda, decoderConfigFileName + ".MIRA.final", decoderConfigFileName
+ + ".MIRA.orig");
+
+ // delete current decoder config file and decoder output
+ deleteFile(decoderConfigFileName);
+ deleteFile(decoderOutFileName);
+
+ // restore original name for config file (name was changed
+ // in initialize() so it doesn't get overwritten)
+ renameFile(decoderConfigFileName + ".MIRA.orig", decoderConfigFileName);
+
+ if (finalLambdaFileName != null) {
+ try {
+ PrintWriter outFile_lambdas = new PrintWriter(finalLambdaFileName);
+ for (int c = 1; c <= numParams; ++c) {
+ outFile_lambdas.println(Vocabulary.word(c) + " ||| " + lambda.get(c).doubleValue());
+ }
+ outFile_lambdas.close();
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+  /**
+   * Reads a config file and flattens it into a command-line style argument array.
+   *
+   * <p>Blank lines and lines whose first character is '#' are ignored, and anything after a '#'
+   * is dropped as a comment. Each remaining line is split on whitespace, with a pair of single
+   * quotes grouping its contents into one argument. A line must be either
+   * {@code -option value}, or {@code -m}/{@code -docSet} followed by more than one value.
+   *
+   * @param fileName path of the config file (passed to checkFile first)
+   * @return the arguments in file order
+   * @throws RuntimeException on an I/O error or a malformed line
+   */
+  private String[] cfgFileToArgsArray(String fileName) {
+    checkFile(fileName);
+
+    Vector<String> argsVector = new Vector<String>();
+
+    // try-with-resources closes the reader even when a malformed line aborts parsing
+    try (BufferedReader inFile = new BufferedReader(new FileReader(fileName))) {
+      String line, origLine;
+      do {
+        line = inFile.readLine();
+        origLine = line; // for error reporting purposes
+
+        if (line != null && line.length() > 0 && line.charAt(0) != '#') {
+
+          if (line.indexOf("#") != -1) { // discard comment
+            line = line.substring(0, line.indexOf("#"));
+          }
+
+          line = line.trim();
+
+          // Whitespace-only lines and indented comments are empty here. Skip them instead of
+          // rejecting them as malformed below. ('continue' re-tests the loop condition, and line
+          // is non-null at this point, so reading simply resumes.)
+          if (line.isEmpty()) {
+            continue;
+          }
+
+          // now line should look like "-xxx XXX"
+
+          /*
+           * OBSOLETE MODIFICATION //SPECIAL HANDLING FOR MIRA CLASSIFIER PARAMETERS String[] paramA
+           * = line.split("\\s+");
- *
++ *
+           * if( paramA[0].equals("-classifierParams") ) { String classifierParam = ""; for(int p=1;
+           * p<=paramA.length-1; p++) classifierParam += paramA[p]+" ";
- *
++ *
+           * if(paramA.length>=2) { String[] tmpParamA = new String[2]; tmpParamA[0] = paramA[0];
+           * tmpParamA[1] = classifierParam; paramA = tmpParamA; } else {
+           * println("Malformed line in config file:"); println(origLine); System.exit(70); } }//END
+           * MODIFICATION
+           */
+
+          // cmu modification(from meteor for zmert)
+          // Parse args: split on whitespace, except that text between a pair of single quotes
+          // (whitespace included) forms one argument
+          ArrayList<String> argList = new ArrayList<String>();
+          StringBuilder arg = new StringBuilder();
+          boolean quoted = false;
+          for (int i = 0; i < line.length(); i++) {
+            char ch = line.charAt(i);
+            if (Character.isWhitespace(ch)) {
+              if (quoted) {
+                arg.append(ch);
+              } else if (arg.length() > 0) {
+                argList.add(arg.toString());
+                arg = new StringBuilder();
+              }
+            } else if (ch == '\'') {
+              if (quoted) {
+                argList.add(arg.toString());
+                arg = new StringBuilder();
+              }
+              quoted = !quoted;
+            } else {
+              arg.append(ch);
+            }
+          }
+          if (arg.length() > 0) {
+            argList.add(arg.toString());
+          }
+          // Create paramA
+          String[] paramA = argList.toArray(new String[argList.size()]);
+          // END CMU MODIFICATION
+
+          // startsWith (rather than charAt(0)) also copes with an empty quoted argument ('')
+          if (paramA.length == 2 && paramA[0].startsWith("-")) {
+            argsVector.add(paramA[0]);
+            argsVector.add(paramA[1]);
+          } else if (paramA.length > 2 && (paramA[0].equals("-m") || paramA[0].equals("-docSet"))) {
+            // -m (metricName), -docSet are allowed to have extra options
+            for (String p : paramA) {
+              argsVector.add(p);
+            }
+          } else {
+            throw new RuntimeException("Malformed line in config file:" + origLine);
+          }
+
+        }
+      } while (line != null);
+
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    return argsVector.toArray(new String[argsVector.size()]);
+  }
+
+ private void processArgsArray(String[] args) {
+ processArgsArray(args, true);
+ }
+
+ private void processArgsArray(String[] args, boolean firstTime) {
+ /* set default values */
+ // Relevant files
+ dirPrefix = null;
+ sourceFileName = null;
+ refFileName = "reference.txt";
+ refsPerSen = 1;
+ textNormMethod = 1;
+ paramsFileName = "params.txt";
+ docInfoFileName = null;
+ finalLambdaFileName = null;
+ // MERT specs
+ metricName = "BLEU";
+ metricName_display = metricName;
+ metricOptions = new String[2];
+ metricOptions[0] = "4";
+ metricOptions[1] = "closest";
+ docSubsetInfo = new int[7];
+ docSubsetInfo[0] = 0;
+ maxMERTIterations = 20;
+ prevMERTIterations = 20;
+ minMERTIterations = 5;
+ stopMinIts = 3;
+ stopSigValue = -1;
+ //
+ // /* possibly other early stopping criteria here */
+ //
+ numOptThreads = 1;
+ saveInterFiles = 3;
+ compressFiles = 0;
+ oneModificationPerIteration = false;
+ randInit = false;
+ seed = System.currentTimeMillis();
+ // useDisk = 2;
+ // Decoder specs
+ decoderCommandFileName = null;
+ passIterationToDecoder = false;
+ decoderOutFileName = "output.nbest";
+ validDecoderExitValue = 0;
+ decoderConfigFileName = "dec_cfg.txt";
+ sizeOfNBest = 100;
+ fakeFileNameTemplate = null;
+ fakeFileNamePrefix = null;
+ fakeFileNameSuffix = null;
+ // Output specs
+ verbosity = 1;
+ decVerbosity = 0;
+
+ int i = 0;
+
+ while (i < args.length) {
+ String option = args[i];
+ // Relevant files
+ if (option.equals("-dir")) {
+ dirPrefix = args[i + 1];
+ } else if (option.equals("-s")) {
+ sourceFileName = args[i + 1];
+ } else if (option.equals("-r")) {
+ refFileName = args[i + 1];
+ } else if (option.equals("-rps")) {
+ refsPerSen = Integer.parseInt(args[i + 1]);
+ if (refsPerSen < 1) {
+ throw new RuntimeException("refsPerSen must be positive.");
+ }
+ } else if (option.equals("-txtNrm")) {
+ textNormMethod = Integer.parseInt(args[i + 1]);
+ if (textNormMethod < 0 || textNormMethod > 4) {
+ throw new RuntimeException("textNormMethod should be between 0 and 4");
+ }
+ } else if (option.equals("-p")) {
+ paramsFileName = args[i + 1];
+ } else if (option.equals("-docInfo")) {
+ docInfoFileName = args[i + 1];
+ } else if (option.equals("-fin")) {
+ finalLambdaFileName = args[i + 1];
+ // MERT specs
+ } else if (option.equals("-m")) {
+ metricName = args[i + 1];
+ metricName_display = metricName;
+ if (EvaluationMetric.knownMetricName(metricName)) {
+ int optionCount = EvaluationMetric.metricOptionCount(metricName);
+ metricOptions = new String[optionCount];
+ for (int opt = 0; opt < optionCount; ++opt) {
+ metricOptions[opt] = args[i + opt + 2];
+ }
+ i += optionCount;
+ } else {
+ throw new RuntimeException("Unknown metric name " + metricName + ".");
+ }
+ } else if (option.equals("-docSet")) {
+ String method = args[i + 1];
+
+ if (method.equals("all")) {
+ docSubsetInfo[0] = 0;
+ i += 0;
+ } else if (method.equals("bottom")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 1;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 2;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("top")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 3;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 4;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("window")) {
+ String a1 = args[i + 2];
+ a1 = a1.substring(0, a1.indexOf("d")); // size of window
+ String a2 = args[i + 4];
+ if (a2.indexOf("p") > 0) {
+ docSubsetInfo[0] = 5;
+ a2 = a2.substring(0, a2.indexOf("p"));
+ } else {
+ docSubsetInfo[0] = 6;
+ a2 = a2.substring(0, a2.indexOf("r"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a1);
+ docSubsetInfo[6]
<TRUNCATED>
[15/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
index 16c25cd,0000000..6ad85a8
mode 100755,000000..100755
--- a/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/adagrad/Optimizer.java
@@@ -1,716 -1,0 +1,712 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.adagrad;
+
- import java.util.Collections;
+import java.util.ArrayList;
++import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.Vector;
- import java.lang.Math;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.metrics.EvaluationMetric;
+
+// this class implements the AdaGrad algorithm
+public class Optimizer {
+  /**
+   * Creates an AdaGrad optimizer.
+   *
+   * @param _output output vector (not used for now)
+   * @param _isOptimizable per-parameter optimizability flags
+   * @param _initialLambda initial weights array; {@code paramDim} is set to its length - 1
+   *     (presumably index 0 is unused — confirm). NOTE(review): the array is stored by
+   *     reference, and runOptimizer copies finalLambda back into it, so the caller's array is
+   *     modified.
+   * @param _feat_hash feature hash table
+   * @param _stats_hash suff. stats hash table
+   */
+  public Optimizer(Vector<String> _output, boolean[] _isOptimizable, double[] _initialLambda,
+      HashMap<String, String>[] _feat_hash, HashMap<String, String>[] _stats_hash) {
+    output = _output; // (not used for now)
+    isOptimizable = _isOptimizable;
+    initialLambda = _initialLambda; // initial weights array
- paramDim = initialLambda.length - 1;
++ paramDim = initialLambda.length - 1;
+    feat_hash = _feat_hash; // feature hash table
+    stats_hash = _stats_hash; // suff. stats hash table
+    // finalLambda starts as its own copy of the initial weights
+    finalLambda = new double[initialLambda.length];
+    System.arraycopy(initialLambda, 0, finalLambda, 0, finalLambda.length);
+  }
+
+ //run AdaGrad for one epoch
+ public double[] runOptimizer() {
+ List<Integer> sents = new ArrayList<>();
+ for( int i = 0; i < sentNum; ++i )
+ sents.add(i);
+ double[] avgLambda = new double[initialLambda.length]; //only needed if averaging is required
+ for( int i = 0; i < initialLambda.length; ++i )
+ avgLambda[i] = 0;
+ for ( int iter = 0; iter < adagradIter; ++iter ) {
+ System.arraycopy(finalLambda, 1, initialLambda, 1, paramDim);
+ if(needShuffle)
+ Collections.shuffle(sents);
-
++
+ double oraMetric, oraScore, predMetric, predScore;
+ double[] oraPredScore = new double[4];
+ double loss = 0;
+ double diff = 0;
+ double sumMetricScore = 0;
+ double sumModelScore = 0;
+ String oraFeat = "";
+ String predFeat = "";
+ String[] oraPredFeat = new String[2];
+ String[] vecOraFeat;
+ String[] vecPredFeat;
+ String[] featInfo;
- int thisBatchSize = 0;
+ int numBatch = 0;
+ int numUpdate = 0;
- Iterator it;
++ Iterator<Integer> it;
+ Integer diffFeatId;
+
+ //update weights
+ Integer s;
+ int sentCount = 0;
+ double prevLambda = 0;
+ double diffFeatVal = 0;
+ double oldVal = 0;
+ double gdStep = 0;
+ double Hii = 0;
+ double gradiiSquare = 0;
+ int lastUpdateTime = 0;
+ HashMap<Integer, Integer> lastUpdate = new HashMap<>();
+ HashMap<Integer, Double> lastVal = new HashMap<>();
+ HashMap<Integer, Double> H = new HashMap<>();
+ while( sentCount < sentNum ) {
+ loss = 0;
- thisBatchSize = batchSize;
+ ++numBatch;
+ HashMap<Integer, Double> featDiff = new HashMap<>();
+ for(int b = 0; b < batchSize; ++b ) {
+ //find out oracle and prediction
+ s = sents.get(sentCount);
+ findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
-
++
+ //the model scores here are already scaled in findOraPred
+ oraMetric = oraPredScore[0];
+ oraScore = oraPredScore[1];
+ predMetric = oraPredScore[2];
+ predScore = oraPredScore[3];
+ oraFeat = oraPredFeat[0];
+ predFeat = oraPredFeat[1];
-
++
+ //update the scale
+ if(needScale) { //otherwise featscale remains 1.0
+ sumMetricScore += Math.abs(oraMetric + predMetric);
+ //restore the original model score
+ sumModelScore += Math.abs(oraScore + predScore) / featScale;
-
++
+ if(sumModelScore/sumMetricScore > scoreRatio)
+ featScale = sumMetricScore/sumModelScore;
+ }
+ // processedSent++;
-
++
+ vecOraFeat = oraFeat.split("\\s+");
+ vecPredFeat = predFeat.split("\\s+");
+
+ //accumulate difference feature vector
+ if ( b == 0 ) {
+ for (String aVecOraFeat : vecOraFeat) {
+ featInfo = aVecOraFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (String aVecPredFeat : vecPredFeat) {
+ featInfo = aVecPredFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId) - Double.parseDouble(featInfo[1]);
+ if (Math.abs(diff) > 1e-20)
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ } else //features only firing in the 2nd feature vector
+ featDiff.put(diffFeatId, -1.0 * Double.parseDouble(featInfo[1]));
+ }
+ } else {
+ for (String aVecOraFeat : vecOraFeat) {
+ featInfo = aVecOraFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId) + Double.parseDouble(featInfo[1]);
+ if (Math.abs(diff) > 1e-20)
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ } else //features only firing in the new oracle feature vector
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (String aVecPredFeat : vecPredFeat) {
+ featInfo = aVecPredFeat.split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId) - Double.parseDouble(featInfo[1]);
+ if (Math.abs(diff) > 1e-20)
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ } else //features only firing in the new prediction feature vector
+ featDiff.put(diffFeatId, -1.0 * Double.parseDouble(featInfo[1]));
+ }
+ }
+
+ //remember the model scores here are already scaled
+ double singleLoss = evalMetric.getToBeMinimized() ?
- (predMetric-oraMetric) - (oraScore-predScore)/featScale:
++ (predMetric-oraMetric) - (oraScore-predScore)/featScale:
+ (oraMetric-predMetric) - (oraScore-predScore)/featScale;
+ if(singleLoss > 0)
+ loss += singleLoss;
+ ++sentCount;
+ if( sentCount >= sentNum ) {
- thisBatchSize = b + 1;
+ break;
+ }
+ } //for(int b : batchSize)
+
+ //System.out.println("\n\n"+sentCount+":");
+
+ if( loss > 0 ) {
+ //if(true) {
+ ++numUpdate;
+ //update weights (see Duchi'11, Eq.23. For l1-reg, use lazy update)
+ Set<Integer> diffFeatSet = featDiff.keySet();
+ it = diffFeatSet.iterator();
+ while(it.hasNext()) { //note these are all non-zero gradients!
- diffFeatId = (Integer)it.next();
++ diffFeatId = it.next();
+ diffFeatVal = -1.0 * featDiff.get(diffFeatId); //gradient
+ if( regularization > 0 ) {
+ lastUpdateTime =
+ lastUpdate.get(diffFeatId) == null ? 0 : lastUpdate.get(diffFeatId);
+ if( lastUpdateTime < numUpdate - 1 ) {
+ //haven't been updated (gradient=0) for at least 2 steps
+ //lazy compute prevLambda now
+ oldVal =
+ lastVal.get(diffFeatId) == null ? initialLambda[diffFeatId] : lastVal.get(diffFeatId);
+ Hii =
+ H.get(diffFeatId) == null ? 0 : H.get(diffFeatId);
+ if(Math.abs(Hii) > 1e-20) {
+ if( regularization == 1 )
+ prevLambda =
+ Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numBatch - 1 - lastUpdateTime) / Hii );
+ else if( regularization == 2 ) {
+ prevLambda =
+ Math.pow( Hii/(lam+Hii), (numUpdate - 1 - lastUpdateTime) ) * oldVal;
+ if(needAvg) { //fill the gap due to lazy update
+ double prevLambdaCopy = prevLambda;
+ double scale = Hii/(lam+Hii);
+ for( int t = 0; t < numUpdate - 1 - lastUpdateTime; ++t ) {
+ avgLambda[diffFeatId] += prevLambdaCopy;
+ prevLambdaCopy /= scale;
+ }
+ }
+ }
+ } else {
+ if( regularization == 1 )
+ prevLambda = 0;
+ else if( regularization == 2 )
+ prevLambda = oldVal;
+ }
+ } else //just updated at last time step or just started
+ prevLambda = finalLambda[diffFeatId];
+ if(H.get(diffFeatId) != null) {
+ gradiiSquare = H.get(diffFeatId);
+ gradiiSquare *= gradiiSquare;
+ gradiiSquare += diffFeatVal * diffFeatVal;
+ Hii = Math.sqrt(gradiiSquare);
+ } else
+ Hii = Math.abs(diffFeatVal);
+ H.put(diffFeatId, Hii);
+ //update the weight
+ if( regularization == 1 ) {
+ gdStep = prevLambda - eta * diffFeatVal / Hii;
+ finalLambda[diffFeatId] = Math.signum(gdStep) * clip( Math.abs(gdStep) - lam * eta / Hii );
+ } else if(regularization == 2 ) {
+ finalLambda[diffFeatId] = (Hii * prevLambda - eta * diffFeatVal) / (lam + Hii);
+ if(needAvg)
+ avgLambda[diffFeatId] += finalLambda[diffFeatId];
+ }
+ lastUpdate.put(diffFeatId, numUpdate);
+ lastVal.put(diffFeatId, finalLambda[diffFeatId]);
+ } else { //if no regularization
+ if(H.get(diffFeatId) != null) {
+ gradiiSquare = H.get(diffFeatId);
+ gradiiSquare *= gradiiSquare;
+ gradiiSquare += diffFeatVal * diffFeatVal;
+ Hii = Math.sqrt(gradiiSquare);
+ } else
+ Hii = Math.abs(diffFeatVal);
+ H.put(diffFeatId, Hii);
+ finalLambda[diffFeatId] = finalLambda[diffFeatId] - eta * diffFeatVal / Hii;
+ if(needAvg)
+ avgLambda[diffFeatId] += finalLambda[diffFeatId];
+ }
+ } //while(it.hasNext())
+ } //if(loss > 0)
+ else { //no loss, therefore the weight update is skipped
+ //however, the avg weights still need to be accumulated
+ if( regularization == 0 ) {
+ for( int i = 1; i < finalLambda.length; ++i )
+ avgLambda[i] += finalLambda[i];
+ } else if( regularization == 2 ) {
+ if(needAvg) {
+ //due to lazy update, we need to figure out the actual
+ //weight vector at this point first...
+ for( int i = 1; i < finalLambda.length; ++i ) {
+ if( lastUpdate.get(i) != null ) {
+ if( lastUpdate.get(i) < numUpdate ) {
+ oldVal = lastVal.get(i);
+ Hii = H.get(i);
+ //lazy compute
+ avgLambda[i] +=
+ Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
+ } else
+ avgLambda[i] += finalLambda[i];
+ }
+ avgLambda[i] += finalLambda[i];
+ }
+ }
+ }
+ }
+ } //while( sentCount < sentNum )
+ if( regularization > 0 ) {
+ for( int i = 1; i < finalLambda.length; ++i ) {
+ //now lazy compute those weights that haven't been taken care of
+ if( lastUpdate.get(i) == null )
+ finalLambda[i] = 0;
+ else if( lastUpdate.get(i) < numUpdate ) {
+ oldVal = lastVal.get(i);
+ Hii = H.get(i);
+ if( regularization == 1 )
+ finalLambda[i] =
+ Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numUpdate - lastUpdate.get(i)) / Hii );
+ else if( regularization == 2 ) {
- finalLambda[i] =
++ finalLambda[i] =
+ Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
+ if(needAvg) { //fill the gap due to lazy update
+ double prevLambdaCopy = finalLambda[i];
+ double scale = Hii/(lam+Hii);
+ for( int t = 0; t < numUpdate - lastUpdate.get(i); ++t ) {
+ avgLambda[i] += prevLambdaCopy;
+ prevLambdaCopy /= scale;
+ }
+ }
+ }
+ }
+ if( regularization == 2 && needAvg ) {
+ if( iter == adagradIter - 1 )
+ finalLambda[i] = avgLambda[i] / ( numBatch * adagradIter );
+ }
+ }
+ } else { //if no regularization
+ if( iter == adagradIter - 1 && needAvg ) {
+ for( int i = 1; i < finalLambda.length; ++i )
+ finalLambda[i] = avgLambda[i] / ( numBatch * adagradIter );
+ }
+ }
+
+ double initMetricScore;
+ if (iter == 0) {
+ initMetricScore = computeCorpusMetricScore(initialLambda);
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ } else {
+ initMetricScore = finalMetricScore;
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ }
+ // prepare the printing info
+ String result = " Initial "
+ + evalMetric.get_metricName() + "=" + String.format("%.4f", initMetricScore) + " Final "
+ + evalMetric.get_metricName() + "=" + String.format("%.4f", finalMetricScore);
+ //print lambda info
+ // int numParamToPrint = 0;
+ // numParamToPrint = paramDim > 10 ? 10 : paramDim; // how many parameters
+ // // to print
+ // result = paramDim > 10 ? "Final lambda (first 10): {" : "Final lambda: {";
-
++
+ // for (int i = 1; i <= numParamToPrint; ++i)
+ // result += String.format("%.4f", finalLambda[i]) + " ";
+
+ output.add(result);
+ } //for ( int iter = 0; iter < adagradIter; ++iter ) {
+
+ //non-optimizable weights should remain unchanged
+ ArrayList<Double> cpFixWt = new ArrayList<>();
+ for ( int i = 1; i < isOptimizable.length; ++i ) {
+ if ( ! isOptimizable[i] )
+ cpFixWt.add(finalLambda[i]);
+ }
+ normalizeLambda(finalLambda);
+ int countNonOpt = 0;
+ for ( int i = 1; i < isOptimizable.length; ++i ) {
+ if ( ! isOptimizable[i] ) {
+ finalLambda[i] = cpFixWt.get(countNonOpt);
+ ++countNonOpt;
+ }
+ }
+ return finalLambda;
+ }
+
+  /**
+   * Clamps a value at zero from below: positive inputs are returned as-is, everything else
+   * (zero, negatives, and NaN) maps to 0.
+   */
+  private double clip(double x) {
+    if (!(x > 0)) {
+      return 0;
+    }
+    return x;
+  }
+
+  /**
+   * Scores the corpus under the given weights: for every sentence, the candidate with the highest
+   * model score (sum of feature value times weight) is taken as the 1-best, its sufficient
+   * statistics are summed into corpus-level totals, and the metric is evaluated on those totals.
+   *
+   * <p>Sentences for which no candidate is selected (e.g. an empty candidate map) contribute
+   * nothing, instead of re-using the statistics of the previous sentence.
+   *
+   * @param finalLambda weight vector indexed by vocabulary feature id
+   * @return the corpus-level metric score
+   */
+  public double computeCorpusMetricScore(double[] finalLambda) {
+    int suffStatsCount = evalMetric.get_suffStatsCount();
+    // New int arrays are zero-initialized, so no explicit reset is needed.
+    int[] corpusStatsVal = new int[suffStatsCount];
+
+    for (int i = 0; i < sentNum; i++) {
+      // Find the 1-best candidate of this sentence under finalLambda.
+      double maxModelScore = NegInf;
+      String[] bestStatsVal = null; // reset per sentence so stale stats are never re-used
+      for (String candStr : feat_hash[i].keySet()) {
+        double modelScore = 0.0;
+        for (String featStr : feat_hash[i].get(candStr).split("\\s+")) {
+          String[] featInfo = featStr.split("=");
+          modelScore += Double.parseDouble(featInfo[1]) * finalLambda[Vocabulary.id(featInfo[0])];
+        }
+        if (maxModelScore < modelScore) {
+          maxModelScore = modelScore;
+          bestStatsVal = stats_hash[i].get(candStr).split("\\s+");
+        }
+      }
+
+      if (bestStatsVal == null) {
+        // No candidate selected for this sentence: nothing to accumulate.
+        continue;
+      }
+
+      // Accumulate the 1-best's sufficient statistics into the corpus-level totals.
+      for (int j = 0; j < suffStatsCount; j++) {
+        corpusStatsVal[j] += Integer.parseInt(bestStatsVal[j]);
+      }
+    }
+
+    return evalMetric.score(corpusStatsVal);
+  }
-
++
+ /**
+ * Picks an oracle candidate and a prediction candidate for sentence sentId from its candidate
+ * set, using lambda for model scores (each model score is multiplied by featScale).
+ *
+ * On return: oraPredScore = {oracle metric, oracle model score, prediction metric,
+ * prediction model score}; oraPredFeat = {oracle features, prediction features}, each a
+ * space-separated "featId=value" string restricted to optimizable features.
+ * Comparisons use <= / >=, so among tied candidates the one visited last (HashMap key order) wins.
+ * When usePseudoBleu is set and the metric is BLEU or TER-BLEU, bleuHistory[sentId] is decayed
+ * by R and the oracle's BLEU statistics are added to it.
+ * Note: the featScale parameter shadows the static field of the same name.
+ */
+ private void findOraPred(int sentId, double[] oraPredScore, String[] oraPredFeat, double[] lambda, double featScale)
+ {
+ double oraMetric=0, oraScore=0, predMetric=0, predScore=0;
+ String oraFeat="", predFeat="";
+ double candMetric = 0, candScore = 0; //metric and model scores for each cand
+ Set<String> candSet = stats_hash[sentId].keySet();
+ String cand = "";
+ String feats = "";
+ String oraCand = ""; //only used when BLEU/TER-BLEU is used as metric
+ String[] featStr;
+ String[] featInfo;
-
++
+ int actualFeatId;
+ double bestOraScore;
+ double worstPredScore;
-
++
+ //oracle selection: mode 1 = "hope" (model score plus metric gain), otherwise the best metric score
+ if(oraSelectMode==1)
+ bestOraScore = NegInf; //larger score will be selected
+ else {
+ if(evalMetric.getToBeMinimized())
+ bestOraScore = PosInf; //smaller score will be selected
+ else
+ bestOraScore = NegInf;
+ }
-
++
+ //prediction selection: mode 1 = "fear" (model score minus metric gain), 2 = highest model score,
+ //otherwise the worst metric score
+ if(predSelectMode==1 || predSelectMode==2)
+ worstPredScore = NegInf; //larger score will be selected
+ else {
+ if(evalMetric.getToBeMinimized())
+ worstPredScore = NegInf; //larger score will be selected
+ else
+ worstPredScore = PosInf;
+ }
+
+ //NOTE(review): iterates over stats_hash keys but reads feat_hash with the same key;
+ //presumably both maps share the same candidate keys -- confirm
+ for (String aCandSet : candSet) {
+ cand = aCandSet.toString();
+ candMetric = computeSentMetric(sentId, cand); //compute metric score
+
+ //start to compute model score
+ candScore = 0;
+ featStr = feat_hash[sentId].get(cand).split("\\s+");
+ feats = "";
+
+ for (String aFeatStr : featStr) {
+ featInfo = aFeatStr.split("=");
+ actualFeatId = Vocabulary.id(featInfo[0]);
+ candScore += Double.parseDouble(featInfo[1]) * lambda[actualFeatId];
+ //only optimizable features are recorded; ids beyond isOptimizable are treated as optimizable
+ if ((actualFeatId < isOptimizable.length && isOptimizable[actualFeatId])
+ || actualFeatId >= isOptimizable.length)
+ feats += actualFeatId + "=" + Double.parseDouble(featInfo[1]) + " ";
+ }
+
+ candScore *= featScale; //scale the model score
+
+ //is this cand oracle?
+ if (oraSelectMode == 1) {//"hope", b=1, r=1
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (bestOraScore <= (candScore - candMetric)) {
+ bestOraScore = candScore - candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= (candScore + candMetric)) {
+ bestOraScore = candScore + candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ } else {//best metric score(ex: max BLEU), b=1, r=0
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (bestOraScore >= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ }
+
+ //is this cand prediction?
+ if (predSelectMode == 1) {//"fear"
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (worstPredScore <= (candScore + candMetric)) {
+ worstPredScore = candScore + candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore <= (candScore - candMetric)) {
+ worstPredScore = candScore - candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ } else if (predSelectMode == 2) {//model prediction(max model score)
+ if (worstPredScore <= candScore) {
+ worstPredScore = candScore;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {//worst metric score(ex: min BLEU)
+ if (evalMetric.getToBeMinimized()) {//if the smaller the metric score, the better
+ if (worstPredScore <= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore >= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ }
+ }
-
++
+ //pack results: [0]=oracle metric, [1]=oracle model score, [2]=prediction metric, [3]=prediction model score
+ oraPredScore[0] = oraMetric;
+ oraPredScore[1] = oraScore;
+ oraPredScore[2] = predMetric;
+ oraPredScore[3] = predScore;
+ oraPredFeat[0] = oraFeat;
+ oraPredFeat[1] = predFeat;
-
++
+ //NOTE(review): if candSet is empty, oraCand is still "" and stats_hash[sentId].get("") is
+ //presumably null, so the updates below would throw -- confirm callers never pass an empty list
+ //update the BLEU metric statistics if pseudo corpus is used to compute BLEU/TER-BLEU
+ if(evalMetric.get_metricName().equals("BLEU") && usePseudoBleu ) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j]);
+ }
-
++
+ if(evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu ) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount()-2; j++)
+ bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j+2]); //the first 2 stats are TER stats
+ }
+ }
-
++
+  // compute *sentence-level* metric score for cand
+  /**
+   * Computes the sentence-level metric score of candidate {@code cand} for sentence
+   * {@code sentId} from its sufficient statistics in {@code stats_hash}.
+   *
+   * <p>When {@code usePseudoBleu} is set, the BLEU statistics are augmented with the decayed
+   * pseudo-corpus history in {@code bleuHistory[sentId]}: all statistics for "BLEU", and all but
+   * the first two (TER) statistics for "TER-BLEU". The two TER statistics are always taken from
+   * the candidate as-is rather than being left at zero.
+   */
+  private double computeSentMetric(int sentId, String cand) {
+    int suffStatsCount = evalMetric.get_suffStatsCount();
+    String[] statVal_str = stats_hash[sentId].get(cand).split("\\s+");
+
+    // Start from the candidate's own sufficient statistics.
+    int[] statVal = new int[suffStatsCount];
+    for (int j = 0; j < suffStatsCount; j++)
+      statVal[j] = Integer.parseInt(statVal_str[j]);
+
+    if (usePseudoBleu) {
+      String metricName = evalMetric.get_metricName();
+      if (metricName.equals("BLEU")) {
+        for (int j = 0; j < suffStatsCount; j++)
+          statVal[j] = (int) (statVal[j] + bleuHistory[sentId][j]);
+      } else if (metricName.equals("TER-BLEU")) {
+        // The first 2 stats are TER stats; only the BLEU part gets the history added.
+        for (int j = 0; j < suffStatsCount - 2; j++)
+          statVal[j + 2] = (int) (statVal[j + 2] + bleuHistory[sentId][j]);
+      }
+    }
+
+    return evalMetric.score(statVal);
+  }
+
+  // from ZMERT
+  /**
+   * Rescales {@code origLambda[1..paramDim]} in place according to {@code normalizationOptions}
+   * (nO below); index 0 is never touched.
+   * <ul>
+   *   <li>nO[0] = 0: no normalization
+   *   <li>nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+   *   <li>nO[0] = 2: scale so that the maximum absolute value is nO[1]
+   *   <li>nO[0] = 3: scale so that the minimum absolute value is nO[1]
+   *   <li>nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+   * </ul>
+   * If the reference magnitude is zero or NaN, the weights are left unscaled. Examples are
+   * method 3 when some weight is exactly 0, or an all-zero vector. Dividing by such a reference
+   * would turn every weight into Infinity/NaN.
+   */
+  private void normalizeLambda(double[] origLambda) {
+    int normalizationMethod = (int) normalizationOptions[0];
+    double target = 1.0;    // desired magnitude after scaling
+    double reference = 1.0; // current magnitude of the normalized quantity
+
+    if (normalizationMethod == 1) {
+      int c = (int) normalizationOptions[2];
+      target = normalizationOptions[1];
+      reference = Math.abs(origLambda[c]);
+    } else if (normalizationMethod == 2) {
+      double maxAbsVal = -1;
+      int maxAbsVal_c = 0;
+      for (int c = 1; c <= paramDim; ++c) {
+        if (Math.abs(origLambda[c]) > maxAbsVal) {
+          maxAbsVal = Math.abs(origLambda[c]);
+          maxAbsVal_c = c;
+        }
+      }
+      target = normalizationOptions[1];
+      reference = Math.abs(origLambda[maxAbsVal_c]);
+    } else if (normalizationMethod == 3) {
+      double minAbsVal = PosInf;
+      int minAbsVal_c = 0;
+      for (int c = 1; c <= paramDim; ++c) {
+        if (Math.abs(origLambda[c]) < minAbsVal) {
+          minAbsVal = Math.abs(origLambda[c]);
+          minAbsVal_c = c;
+        }
+      }
+      target = normalizationOptions[1];
+      reference = Math.abs(origLambda[minAbsVal_c]);
+    } else if (normalizationMethod == 4) {
+      target = normalizationOptions[2];
+      reference = L_norm(origLambda, normalizationOptions[1]);
+    }
+
+    // Guard against a zero/NaN reference, which would corrupt every weight.
+    double scalingFactor = reference > 0 ? target / reference : 1.0;
+    for (int c = 1; c <= paramDim; ++c) {
+      origLambda[c] *= scalingFactor;
+    }
+  }
+
+  // from ZMERT
+  /**
+   * Returns the L-{@code pow} norm of {@code A} with element 0 excluded, i.e.
+   * (sum over i >= 1 of |A[i]|^pow)^(1/pow).
+   */
+  private double L_norm(double[] A, double pow) {
+    double powerSum = 0.0;
+    int idx = 1; // A[0] is intentionally skipped
+    while (idx < A.length) {
+      powerSum += Math.pow(Math.abs(A[idx]), pow);
+      ++idx;
+    }
+    return Math.pow(powerSum, 1.0 / pow);
+  }
+
+  /**
+   * @return the shared feature-scaling factor that model scores are multiplied by so they are
+   *     comparable with metric scores
+   */
+  public static double getScale() {
+    return featScale;
+  }
-
++
+  /**
+   * (Re)allocates the pseudo-corpus BLEU history: one row of {@code statCount} statistics per
+   * sentence. Newly allocated Java arrays are zero-initialized, so every entry starts at 0.0 and no
+   * explicit fill loop is needed.
+   *
+   * @param sentNum number of sentences (rows)
+   * @param statCount number of sufficient statistics per sentence (columns)
+   */
+  public static void initBleuHistory(int sentNum, int statCount) {
+    bleuHistory = new double[sentNum][statCount];
+  }
+
+  /**
+   * @return the corpus-level metric score computed with the final weights of the most recent
+   *     AdaGrad iteration
+   */
+  public double getMetricScore() {
+    final double latestScore = finalMetricScore;
+    return latestScore;
+  }
-
++
+ //per-iteration summary lines (initial/final metric score)
+ private final Vector<String> output;
+ //starting weights; also the fallback value for lazily-updated weights
+ private double[] initialLambda;
+ //weights being optimized and finally returned
+ private final double[] finalLambda;
+ //corpus metric score of finalLambda after the latest iteration
+ private double finalMetricScore;
+ //per sentence: candidate -> whitespace-separated "featName=value" string
+ private final HashMap<String, String>[] feat_hash;
+ //per sentence: candidate -> whitespace-separated integer sufficient statistics
+ private final HashMap<String, String>[] stats_hash;
+ //number of weights touched by normalization (indices 1..paramDim)
+ private final int paramDim;
+ //false entries are excluded from feature differences and restored after normalization
+ private final boolean[] isOptimizable;
+ public static int sentNum;
+ public static int adagradIter; //AdaGrad internal iterations
+ //1 = "hope" oracle, otherwise best-metric oracle (see findOraPred)
+ public static int oraSelectMode;
+ //1 = "fear", 2 = max model score, otherwise worst-metric prediction (see findOraPred)
+ public static int predSelectMode;
+ public static int batchSize;
+ //0 = no regularization; 1 and 2 select the clip-based and shrinkage-based update rules
+ public static int regularization;
+ public static boolean needShuffle;
+ public static boolean needScale;
+ public static double scoreRatio;
+ //accumulate and return averaged weights
+ public static boolean needAvg;
+ public static boolean usePseudoBleu;
+ public static double featScale = 1.0; //scale the features in order to make the model score comparable with metric score
+ //updates in each epoch if necessary
+ //step size multiplying the gradient in the weight update
+ public static double eta;
+ //regularization strength
+ public static double lam;
- public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
++ public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
+ public static EvaluationMetric evalMetric;
+ //see normalizeLambda for the meaning of each entry
+ public static double[] normalizationOptions;
+ //per sentence: decayed pseudo-corpus BLEU statistics
+ public static double[][] bleuHistory;
-
++
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
index 10efdc6,0000000..27303ec
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
+++ b/joshua-core/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
@@@ -1,412 -1,0 +1,414 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.corpus.syntax;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.util.io.LineReader;
+
+public class ArraySyntaxTree implements SyntaxTree, Externalizable {
+
+  /**
+   * Note that index stores the indices of lattice node positions, i.e. the last element of index is
+   * the terminal node, pointing to lattice.size()
+   *
+   * <p>Constituents starting at terminal position {@code p} occupy
+   * {@code forwardLattice[forwardIndex[p], forwardIndex[p + 1])} as consecutive
+   * (label id, span length) pairs, longest span first. The backward arrays hold the same pairs
+   * grouped by end position, shortest span first.
+   */
+  private ArrayList<Integer> forwardIndex;
+  private ArrayList<Integer> forwardLattice;
+  private ArrayList<Integer> backwardIndex;
+  private ArrayList<Integer> backwardLattice;
+
+  /** Vocabulary ids of the leaf tokens, in sentence order. */
+  private ArrayList<Integer> terminals;
+
+  private final boolean useBackwardLattice = true;
+
+  private static final int MAX_CONCATENATIONS = 3;
+  private static final int MAX_LABELS = 100;
+
+  /** Creates an empty tree; all lattices stay null until populated (e.g. by readExternalText). */
+  public ArraySyntaxTree() {
+    forwardIndex = null;
+    forwardLattice = null;
+    backwardIndex = null;
+    backwardLattice = null;
+
+    terminals = null;
+  }
+
+  /**
+   * Builds a tree from a single Penn Treebank-style bracketed parse.
+   *
+   * @param parsed_line a bracketed parse such as "(S (NP w1) (VP w2))"
+   */
+  public ArraySyntaxTree(String parsed_line) {
+    initialize();
+    appendFromPennFormat(parsed_line);
+  }
+
+  /**
+   * Returns a collection of single-non-terminal labels that exactly cover the specified span in the
+   * lattice.
+   */
+  @Override
+  public Collection<Integer> getConstituentLabels(int from, int to) {
+    Collection<Integer> labels = new HashSet<>();
+    int span_length = to - from;
+    for (int i = forwardIndex.get(from); i < forwardIndex.get(from + 1); i += 2) {
+      int current_span = forwardLattice.get(i + 1);
+      if (current_span == span_length) {
+        labels.add(forwardLattice.get(i));
+      } else if (current_span < span_length) {
+        break; // spans are stored longest-first; nothing shorter can match
+      }
+    }
+    return labels;
+  }
+
+  /**
+   * Returns the id of one constituent label that exactly covers [from, to), or 0 if there is none.
+   */
+  public int getOneConstituent(int from, int to) {
+    int spanLength = to - from;
+    for (int i = forwardIndex.get(from); i < forwardIndex.get(from + 1); i += 2) {
+      int currentSpan = forwardLattice.get(i + 1);
+      if (currentSpan == spanLength) {
+        return forwardLattice.get(i);
+      } else if (currentSpan < spanLength) {
+        break; // spans are stored longest-first; nothing shorter can match
+      }
+    }
+    return 0;
+  }
+
+  /**
+   * Returns a label "X+Y" where X covers [from, midpt) and Y covers [midpt, to) for the first
+   * split point that has both, or 0 if there is none.
+   */
+  public int getOneSingleConcatenation(int from, int to) {
+    for (int midpt = from + 1; midpt < to; midpt++) {
+      int x = getOneConstituent(from, midpt);
+      if (x == 0) continue;
+      int y = getOneConstituent(midpt, to);
+      if (y == 0) continue;
+      String label = Vocabulary.word(x) + "+" + Vocabulary.word(y);
+      return Vocabulary.id(adjustMarkup(label));
+    }
+    return 0;
+  }
+
+  /**
+   * Returns a label "X+Y+Z" covering [from, a), [a, b) and [b, to) for the first split points that
+   * have all three constituents, or 0 if there is none.
+   */
+  public int getOneDoubleConcatenation(int from, int to) {
+    for (int a = from + 1; a < to - 1; a++) {
+      // x depends only on a, so look it up once per a rather than once per (a, b).
+      int x = getOneConstituent(from, a);
+      if (x == 0) continue;
+      for (int b = a + 1; b < to; b++) {
+        int y = getOneConstituent(a, b);
+        if (y == 0) continue;
+        int z = getOneConstituent(b, to);
+        if (z == 0) continue;
+        String label = Vocabulary.word(x) + "+" + Vocabulary.word(y) + "+" + Vocabulary.word(z);
+        return Vocabulary.id(adjustMarkup(label));
+      }
+    }
+    return 0;
+  }
+
+  /**
+   * Returns a CCG-style label "X/Y" for [from, to), where X covers [from, end) and Y covers
+   * [to, end), for the nearest end > to that has both; 0 if there is none.
+   */
+  public int getOneRightSideCCG(int from, int to) {
+    // No constituent can end past the last terminal, so bound the search by the sentence length.
+    for (int end = to + 1; end <= terminals.size(); end++) {
+      int x = getOneConstituent(from, end);
+      if (x == 0) continue;
+      int y = getOneConstituent(to, end);
+      if (y == 0) continue;
+      String label = Vocabulary.word(x) + "/" + Vocabulary.word(y);
+      return Vocabulary.id(adjustMarkup(label));
+    }
+    return 0;
+  }
+
+  /**
+   * Returns a CCG-style label "X\Y" for [from, to), where X covers [start, to) and Y covers
+   * [start, from), for the nearest start < from that has both; 0 if there is none. The operands
+   * are ordered main\missing, matching the "X\Y" labels built by getCcgLabels.
+   */
+  public int getOneLeftSideCCG(int from, int to) {
+    for (int start = from - 1; start >= 0; start--) {
+      int x = getOneConstituent(start, to);
+      if (x == 0) continue;
+      int y = getOneConstituent(start, from);
+      if (y == 0) continue;
+      String label = Vocabulary.word(x) + "\\" + Vocabulary.word(y);
+      return Vocabulary.id(adjustMarkup(label));
+    }
+    return 0;
+  }
+
+  /**
+   * Returns a collection of concatenated non-terminal labels that exactly cover the specified span
+   * in the lattice. The number of non-terminals concatenated is limited by MAX_CONCATENATIONS and
+   * the total number of labels returned is bounded by MAX_LABELS.
+   */
+  @Override
+  public Collection<Integer> getConcatenatedLabels(int from, int to) {
+    Collection<Integer> labels = new HashSet<>();
+
+    int span_length = to - from;
+    // Parallel stacks for an explicit depth-first search: label so far, next position, depth.
+    Stack<Integer> nt_stack = new Stack<>();
+    Stack<Integer> pos_stack = new Stack<>();
+    Stack<Integer> depth_stack = new Stack<>();
+
+    // seed stacks (reverse order to save on iterations, longer spans)
+    for (int i = forwardIndex.get(from + 1) - 2; i >= forwardIndex.get(from); i -= 2) {
+      int current_span = forwardLattice.get(i + 1);
+      if (current_span < span_length) {
+        nt_stack.push(forwardLattice.get(i));
+        pos_stack.push(from + current_span);
+        depth_stack.push(1);
+      } else if (current_span >= span_length) {
+        break;
+      }
+    }
+
+    while (!nt_stack.isEmpty() && labels.size() < MAX_LABELS) {
+      int nt = nt_stack.pop();
+      int pos = pos_stack.pop();
+      int depth = depth_stack.pop();
+
+      // maximum depth reached without filling span
+      if (depth == MAX_CONCATENATIONS) continue;
+
+      int remaining_span = to - pos;
+      for (int i = forwardIndex.get(pos + 1) - 2; i >= forwardIndex.get(pos); i -= 2) {
+        int current_span = forwardLattice.get(i + 1);
+        if (current_span > remaining_span) break;
+
+        // create and look up concatenated label
+        int concatenated_nt =
+            Vocabulary.id(adjustMarkup(Vocabulary.word(nt) + "+"
+                + Vocabulary.word(forwardLattice.get(i))));
+        if (current_span < remaining_span) {
+          nt_stack.push(concatenated_nt);
+          pos_stack.push(pos + current_span);
+          depth_stack.push(depth + 1);
+        } else if (current_span == remaining_span) {
+          labels.add(concatenated_nt);
+        }
+      }
+    }
+
+    return labels;
+  }
+
+  // TODO: can pre-compute all that in top-down fashion.
+  /**
+   * Returns CCG-style labels for the span [from, to). "X/Y" is produced when a constituent X
+   * starting at {@code from} extends past {@code to} and Y covers the missing right part.
+   * "X\Y" is produced when a constituent X ending at {@code to} starts before {@code from} and
+   * Y covers the missing left part.
+   */
+  @Override
+  public Collection<Integer> getCcgLabels(int from, int to) {
+    Collection<Integer> labels = new HashSet<>();
+
+    int span_length = to - from;
+    // TODO: range checks on the to and from
+
+    // The longest constituent starting at `from` is stored first.
+    boolean is_prefix = (forwardLattice.get(forwardIndex.get(from) + 1) > span_length);
+    if (is_prefix) {
+      Map<Integer, Set<Integer>> main_constituents = new HashMap<>();
+      // find missing to the right
+      for (int i = forwardIndex.get(from); i < forwardIndex.get(from + 1); i += 2) {
+        int current_span = forwardLattice.get(i + 1);
+        if (current_span <= span_length)
+          break;
+        int end_pos = from + current_span;
+        main_constituents.computeIfAbsent(end_pos, k -> new HashSet<>()).add(forwardLattice.get(i));
+      }
+      for (int i = forwardIndex.get(to); i < forwardIndex.get(to + 1); i += 2) {
+        Set<Integer> main_set = main_constituents.get(to + forwardLattice.get(i + 1));
+        if (main_set != null) {
+          for (int main : main_set)
+            labels.add(Vocabulary.id(adjustMarkup(Vocabulary.word(main) + "/"
+                + Vocabulary.word(forwardLattice.get(i)))));
+        }
+      }
+    } else if (useBackwardLattice) {
+      // check if there is any possible higher-level constituent overlapping
+      int to_end =
+          (to == backwardIndex.size() - 1) ? backwardLattice.size() : backwardIndex.get(to + 1);
+      // check longest span ending in to..
+      if (backwardLattice.get(to_end - 1) <= span_length) return labels;
+
+      Map<Integer, Set<Integer>> main_constituents = new HashMap<>();
+      // find missing to the left
+      for (int i = to_end - 2; i >= backwardIndex.get(to); i -= 2) {
+        int current_span = backwardLattice.get(i + 1);
+        if (current_span <= span_length)
+          break;
+        int start_pos = to - current_span;
+        main_constituents.computeIfAbsent(start_pos, k -> new HashSet<>()).add(backwardLattice.get(i));
+      }
+      for (int i = backwardIndex.get(from); i < backwardIndex.get(from + 1); i += 2) {
+        Set<Integer> main_set = main_constituents.get(from - backwardLattice.get(i + 1));
+        if (main_set != null) {
+          for (int main : main_set)
+            labels.add(Vocabulary.id(adjustMarkup(Vocabulary.word(main) + "\\"
+                + Vocabulary.word(backwardLattice.get(i)))));
+        }
+      }
+    }
+    // TODO: bothersome no-backwards-arrays method (non-prefix spans when useBackwardLattice is false).
+    return labels;
+  }
+
+  @Override
+  public int[] getTerminals() {
+    return getTerminals(0, terminals.size());
+  }
+
+  @Override
+  public int[] getTerminals(int from, int to) {
+    int[] span = new int[to - from];
+    for (int i = from; i < to; i++)
+      span[i - from] = terminals.get(i);
+    return span;
+  }
+
+  /**
+   * NOTE(review): serialization is not implemented; this reads nothing, so a deserialized
+   * instance keeps the null lattices set by the no-arg constructor.
+   */
+  @Override
+  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+    // TODO Auto-generated method stub
+  }
+
+  /** NOTE(review): serialization is not implemented; nothing is written. */
+  @Override
+  public void writeExternal(ObjectOutput out) throws IOException {
+    // TODO Auto-generated method stub
+  }
+
+  /**
+   * Reads Penn Treebank format file
+   * @param file_name the string path of the Penn Treebank file
+   * @throws IOException if the file does not exist
+   */
+  public void readExternalText(String file_name) throws IOException {
- LineReader reader = new LineReader(file_name);
- initialize();
- for (String line : reader) {
- if (line.trim().equals("")) continue;
- appendFromPennFormat(line);
+    // try-with-resources closes the reader even if parsing a line throws.
+    try (LineReader reader = new LineReader(file_name)) {
+      initialize();
+      for (String line : reader) {
+        if (line.trim().isEmpty()) continue; // skip blank lines
+        appendFromPennFormat(line);
+      }
+    }
+  }
+
- public void writeExternalText(String file_name) throws IOException {
- // TODO Auto-generated method stub
- }
-
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    for (int i = 0; i < forwardIndex.size(); i++)
+      sb.append("FI[").append(i).append("] =\t").append(forwardIndex.get(i)).append("\n");
+    sb.append("\n");
+    for (int i = 0; i < forwardLattice.size(); i += 2)
+      sb.append("F[").append(i).append("] =\t").append(Vocabulary.word(forwardLattice.get(i)))
+          .append(" , ").append(forwardLattice.get(i + 1)).append("\n");
+
+    sb.append("\n");
+    for (int i = 0; i < terminals.size(); i += 1)
+      sb.append("T[").append(i).append("] =\t").append(Vocabulary.word(terminals.get(i)))
+          .append(" , 1 \n");
+
+    if (this.useBackwardLattice) {
+      sb.append("\n");
+      for (int i = 0; i < backwardIndex.size(); i++)
+        sb.append("BI[").append(i).append("] =\t").append(backwardIndex.get(i)).append("\n");
+      sb.append("\n");
+      for (int i = 0; i < backwardLattice.size(); i += 2)
+        sb.append("B[").append(i).append("] =\t").append(Vocabulary.word(backwardLattice.get(i)))
+            .append(" , ").append(backwardLattice.get(i + 1)).append("\n");
+    }
+    return sb.toString();
+  }
+
+  /** Resets all lattices to an empty tree (index arrays start with the sentinel offset 0). */
+  private void initialize() {
+    forwardIndex = new ArrayList<>();
+    forwardIndex.add(0);
+    forwardLattice = new ArrayList<>();
+    if (this.useBackwardLattice) {
+      backwardIndex = new ArrayList<>();
+      backwardIndex.add(0);
+      backwardLattice = new ArrayList<>();
+    }
+
+    terminals = new ArrayList<>();
+  }
+
+  // TODO: could make this way more efficient
+  /**
+   * Appends one bracketed parse to the lattices. An opening bracket's label is recorded with a
+   * placeholder span holding the current terminal count. The closing bracket replaces the
+   * placeholder with the number of terminals covered.
+   */
+  private void appendFromPennFormat(String line) {
+    String[] tokens = line.replaceAll("\\(", " ( ").replaceAll("\\)", " ) ").trim().split("\\s+");
+
+    boolean next_nt = false;
+    int current_id = 0;
+    Stack<Integer> stack = new Stack<>();
+
+    for (String token : tokens) {
+      if ("(".equals(token)) {
+        next_nt = true;
+        continue;
+      }
+      if (")".equals(token)) {
+        int closing_pos = stack.pop();
+        forwardLattice.set(closing_pos, forwardIndex.size() - forwardLattice.get(closing_pos));
+        if (this.useBackwardLattice) {
+          backwardLattice.add(forwardLattice.get(closing_pos - 1));
+          backwardLattice.add(forwardLattice.get(closing_pos));
+        }
+        continue;
+      }
+      if (next_nt) {
+        // get NT id
+        current_id = Vocabulary.id(adjustMarkup(token));
+        // add into lattice
+        forwardLattice.add(current_id);
+        // push NT span field onto stack (added hereafter, we're just saving the "- 1")
+        stack.push(forwardLattice.size());
+        // add NT span field
+        forwardLattice.add(forwardIndex.size());
+      } else {
+        current_id = Vocabulary.id(token);
+        terminals.add(current_id);
+
+        forwardIndex.add(forwardLattice.size());
+        if (this.useBackwardLattice) backwardIndex.add(backwardLattice.size());
+      }
+      next_nt = false;
+    }
+  }
+
+  /** Wraps a label in square brackets after stripping any brackets it already contains. */
+  private String adjustMarkup(String nt) {
+    return "[" + nt.replaceAll("[\\[\\]]", "") + "]";
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
index 26ed674,0000000..97baa27
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/ArgsParser.java
@@@ -1,116 -1,0 +1,116 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * @author orluke
- *
++ *
+ */
+public class ArgsParser {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ArgsParser.class);
+
+ private String configFile = null;
+
+ /**
+ * Parse the arguments passed from the command line when the JoshuaDecoder application was
+ * executed from the command line.
- *
++ *
+ * @param args string array of input arguments
+ * @param config the {@link org.apache.joshua.decoder.JoshuaConfiguration}
+ * @throws IOException if there is an error wit the input arguments
+ */
+ public ArgsParser(String[] args, JoshuaConfiguration config) throws IOException {
+
+ /*
+ * Look for a verbose flag, -v.
- *
- * Look for an argument to the "-config" flag to find the config file, if any.
++ *
++ * Look for an argument to the "-config" flag to find the config file, if any.
+ */
+ if (args.length >= 1) {
+ // Search for a verbose flag
+ for (int i = 0; i < args.length; i++) {
+ if (args[i].equals("-v")) {
+ Decoder.VERBOSE = Integer.parseInt(args[i + 1].trim());
+ config.setVerbosity(Decoder.VERBOSE);
+ }
-
- if (args[i].equals("-version")) {
- LineReader reader = new LineReader(String.format("%s/VERSION", System.getenv("JOSHUA")));
- reader.readLine();
- String version = reader.readLine().split("\\s+")[2];
- System.out.println(String.format("The Apache Joshua machine translator, version %s", version));
- System.out.println("joshua.incubator.apache.org");
- System.exit(0);
+
++ if (args[i].equals("-version")) {
++ try (LineReader reader = new LineReader(String.format("%s/VERSION", System.getenv("JOSHUA")));) {
++ reader.readLine();
++ String version = reader.readLine().split("\\s+")[2];
++ System.out.println(String.format("The Apache Joshua machine translator, version %s", version));
++ System.out.println("joshua.incubator.apache.org");
++ System.exit(0);
++ }
+ } else if (args[i].equals("-license")) {
+ try {
+ Files.readAllLines(Paths.get(String.format("%s/../LICENSE",
+ JoshuaConfiguration.class.getProtectionDomain().getCodeSource().getLocation()
+ .getPath())), Charset.defaultCharset()).forEach(System.out::println);
+ } catch (IOException e) {
+ throw new RuntimeException("FATAL: missing license file!", e);
+ }
+ System.exit(0);
+ }
+ }
+
+ // Search for the configuration file from the end (so as to take the last one)
+ for (int i = args.length-1; i >= 0; i--) {
+ if (args[i].equals("-c") || args[i].equals("-config")) {
+
+ setConfigFile(args[i + 1].trim());
+ try {
+ LOG.info("Parameters read from configuration file: {}", getConfigFile());
+ config.readConfigFile(getConfigFile());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ break;
+ }
+ }
+
+ // Now process all the command-line args
+ config.processCommandLineOptions(args);
+ }
+ }
+
+ /**
+ * @return the configFile
+ */
+ public String getConfigFile() {
+ return configFile;
+ }
+
+ /**
+ * @param configFile the configFile to set
+ */
+ public void setConfigFile(String configFile) {
+ this.configFile = configFile;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
index 9cfb6eb,0000000..3d6f3bc
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/Decoder.java
@@@ -1,598 -1,0 +1,597 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature;
+import static org.apache.joshua.decoder.ff.tm.OwnerMap.getOwner;
+import static org.apache.joshua.util.Constants.spaceSeparator;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+
- import com.google.common.base.Strings;
- import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureMap;
+import org.apache.joshua.decoder.ff.FeatureVector;
+import org.apache.joshua.decoder.ff.PhraseModel;
+import org.apache.joshua.decoder.ff.StatefulFF;
+import org.apache.joshua.decoder.ff.lm.LanguageModelFF;
+import org.apache.joshua.decoder.ff.tm.Grammar;
+import org.apache.joshua.decoder.ff.tm.OwnerId;
+import org.apache.joshua.decoder.ff.tm.OwnerMap;
+import org.apache.joshua.decoder.ff.tm.Rule;
+import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
+import org.apache.joshua.decoder.ff.tm.hash_based.MemoryBasedBatchGrammar;
+import org.apache.joshua.decoder.ff.tm.packed.PackedGrammar;
+import org.apache.joshua.decoder.io.TranslationRequestStream;
+import org.apache.joshua.decoder.phrase.PhraseTable;
+import org.apache.joshua.decoder.segment_file.Sentence;
+import org.apache.joshua.util.FileUtility;
+import org.apache.joshua.util.FormatUtils;
+import org.apache.joshua.util.Regex;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
++import com.google.common.base.Strings;
++import com.google.common.util.concurrent.ThreadFactoryBuilder;
++
+/**
+ * This class handles decoder initialization and the complication introduced by multithreading.
+ *
+ * After initialization, the main entry point to the Decoder object is
+ * decodeAll(TranslationRequest), which returns a set of Translation objects wrapped in an iterable
+ * TranslationResponseStream object. It is important that we support multithreading both (a) across the sentences
+ * within a request and (b) across requests, in a round-robin fashion. This is done by maintaining a
+ * fixed sized concurrent thread pool. When a new request comes in, a RequestParallelizer thread is
+ * launched. This object iterates over the request's sentences, obtaining a thread from the
+ * thread pool, and using that thread to decode the sentence. If a decoding thread is not available,
+ * it will block until one becomes available, in a fair (FIFO) manner. RequestParallelizer thereby permits intra-request
+ * parallelization by separating out reading the input stream from processing the translated sentences,
+ * but also ensures that round-robin parallelization occurs, since RequestParallelizer uses the
+ * thread pool before translating each request.
+ *
+ * A decoding thread is handled by DecoderTask and launched from DecoderThreadRunner. The purpose
+ * of the runner is to record where to place the translated sentence when it is done (i.e., which
+ * TranslationResponseStream object). TranslationResponseStream itself is an iterator whose next() call blocks until the next
+ * translation is available.
+ *
+ * @author Matt Post post@cs.jhu.edu
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @author Lane Schwartz dowobeha@users.sourceforge.net
+ * @author Kellen Sunderland kellen.sunderland@gmail.com
+ */
+public class Decoder {
+
+  private static final Logger LOG = LoggerFactory.getLogger(Decoder.class);
+
+  private final JoshuaConfiguration joshuaConfiguration;
+
+  public JoshuaConfiguration getJoshuaConfiguration() {
+    return joshuaConfiguration;
+  }
+
+  /*
+   * Many of these objects themselves are global objects. We pass them in when constructing other
+   * objects, so that they all share pointers to the same object. This is good because it reduces
+   * overhead, but it can be problematic because of unseen dependencies (for example, in the
+   * Vocabulary shared by language model, translation grammar, etc).
+   */
+  private final List<Grammar> grammars = new ArrayList<>();
+  private final ArrayList<FeatureFunction> featureFunctions = new ArrayList<>();
+  private Grammar customPhraseTable = null;
+
+  /* The feature weights. */
+  public static FeatureVector weights;
+
+  public static int VERBOSE = 1;
+
+  /**
+   * Creates a new decoder from the given configuration, loading weights, translation grammars and
+   * feature functions.
+   *
+   * @param joshuaConfiguration a populated {@link org.apache.joshua.decoder.JoshuaConfiguration}
+   */
+  public Decoder(JoshuaConfiguration joshuaConfiguration) {
+    this.joshuaConfiguration = joshuaConfiguration;
+    this.initialize();
+  }
+
+  /**
+   * This function is the main entry point into the decoder. It translates all the sentences in a
+   * (possibly boundless) set of input sentences. Each request launches its own thread to read the
+   * sentences of the request.
+   *
+   * @param request the populated {@link TranslationRequestStream}
+   * @throws RuntimeException if any fatal errors occur during translation
+   * @return an iterable, asynchronously-filled list of TranslationResponseStream
+   */
+  public TranslationResponseStream decodeAll(TranslationRequestStream request) {
+    TranslationResponseStream results = new TranslationResponseStream(request);
+    CompletableFuture.runAsync(() -> decodeAllAsync(request, results));
+    return results;
+  }
+
+  /**
+   * Reads sentences from {@code request} until it returns null and decodes each one on a
+   * fixed-size pool of {@code num_parallel_decoders} daemon threads, recording each result (or
+   * the Throwable it raised) in {@code responseStream}. {@code finish()} is called once all
+   * sentences have been submitted.
+   *
+   * <p>NOTE(review): finish() runs after submission, not after the tasks complete; presumably
+   * TranslationResponseStream accounts for in-flight results — confirm.
+   */
+  private void decodeAllAsync(TranslationRequestStream request,
+                              TranslationResponseStream responseStream) {
+
+    // Give the threadpool a friendly name to help debuggers
+    final ThreadFactory threadFactory = new ThreadFactoryBuilder()
+            .setNameFormat("TranslationWorker-%d")
+            .setDaemon(true)
+            .build();
+    ExecutorService executor = Executors.newFixedThreadPool(this.joshuaConfiguration.num_parallel_decoders,
+            threadFactory);
+    try {
+      for (; ; ) {
+        Sentence sentence = request.next();
+
+        if (sentence == null) {
+          break;
+        }
+
+        executor.execute(() -> {
+          try {
+            Translation result = decode(sentence);
+            responseStream.record(result);
+          } catch (Throwable ex) {
+            responseStream.propagate(ex);
+          }
+        });
+      }
+      responseStream.finish();
+    } finally {
+      executor.shutdown();
+    }
+  }
+
+
+  /**
+   * We can also just decode a single sentence in the same thread.
+   *
+   * @param sentence {@link org.apache.joshua.lattice.Lattice} input
+   * @throws RuntimeException if any fatal errors occur during translation
+   * @return the sentence {@link org.apache.joshua.decoder.Translation}
+   */
+  public Translation decode(Sentence sentence) {
+    DecoderTask decoderTask = new DecoderTask(this.grammars, this.featureFunctions, joshuaConfiguration);
+    return decoderTask.translate(sentence);
+  }
+
+  /**
+   * Clean shutdown of Decoder, resetting all
+   * static variables, such that any other instance of Decoder
+   * afterwards gets a fresh start.
+   */
+  public void cleanUp() {
+    resetGlobalState();
+  }
+
+  /**
+   * Clears the static registries (owners, features, vocabulary, registered language models) and
+   * resets the LM and stateful-feature index counters.
+   */
+  public static void resetGlobalState() {
+    // clear/reset static variables
+    OwnerMap.clear();
+    FeatureMap.clear();
+    Vocabulary.clear();
+    Vocabulary.unregisterLanguageModels();
+    LanguageModelFF.resetLmIndex();
+    StatefulFF.resetGlobalStateIndex();
+  }
+
+  /**
+   * Copies the config file {@code template} to {@code outputFile} (each line trimmed), replacing
+   * the trailing weight of every model line with the next value of {@code newWeights}. Lines
+   * matching {@code Regex.commentOrEmptyLine} or containing '=' are copied unchanged.
+   *
+   * <p>I/O errors are logged, not rethrown.
+   *
+   * @param newWeights replacement weights in model-line order, or null to keep the template's
+   * @param template path of the config file to copy
+   * @param outputFile path of the config file to write
+   * @param newDiscriminativeModel if non-null, replaces the second field of a line whose first
+   *        field is "discriminative"
+   * @throws IllegalArgumentException if a model line does not end in a number, or the number of
+   *         model lines differs from {@code newWeights.length}
+   */
+  public static void writeConfigFile(double[] newWeights, String template, String outputFile,
+      String newDiscriminativeModel) {
+    try {
+      int columnID = 0;
+
+      try (LineReader reader = new LineReader(template);
+          BufferedWriter writer = FileUtility.getWriteFileStream(outputFile)) {
+        for (String line : reader) {
+          line = line.trim();
+          if (Regex.commentOrEmptyLine.matches(line) || line.contains("=")) {
+            // comment, empty line, or parameter lines: just copy
+            writer.write(line);
+            writer.newLine();
+
+          } else { // models: replace the weight
+            String[] fds = Regex.spaces.split(line);
+            StringBuilder newSent = new StringBuilder();
+            if (!Regex.floatingNumber.matches(fds[fds.length - 1])) {
+              throw new IllegalArgumentException("last field is not a number; the field is: "
+                  + fds[fds.length - 1]);
+            }
+
+            if (newDiscriminativeModel != null && "discriminative".equals(fds[0])) {
+              newSent.append(fds[0]).append(' ');
+              newSent.append(newDiscriminativeModel).append(' ');// change the
+              // file name
+              for (int i = 2; i < fds.length - 1; i++) {
+                newSent.append(fds[i]).append(' ');
+              }
+            } else {// regular
+              for (int i = 0; i < fds.length - 1; i++) {
+                newSent.append(fds[i]).append(' ');
+              }
+            }
+            if (newWeights != null) {
+              // Fail with a clear message instead of an ArrayIndexOutOfBoundsException
+              if (columnID >= newWeights.length) {
+                throw new IllegalArgumentException("number of models does not match number of weights");
+              }
+              newSent.append(newWeights[columnID++]);// change the weight
+            } else {
+              newSent.append(fds[fds.length - 1]);// do not change
+            }
+
+            writer.write(newSent.toString());
+            writer.newLine();
+          }
+        }
+      }
+
+      if (newWeights != null && columnID != newWeights.length) {
+        throw new IllegalArgumentException("number of models does not match number of weights");
+      }
+
+    } catch (IOException e) {
+      LOG.error("Unable to write config file '{}' from template '{}'", outputFile, template, e);
+    }
+  }
+
+  /**
+   * Initialize all parts of the JoshuaDecoder.
+   *
+   * <p>NOTE: IOExceptions raised while loading are logged rather than rethrown, so the decoder
+   * may be left partially initialized.
+   */
+  private void initialize() {
+    try {
+
+      long pre_load_time = System.currentTimeMillis();
+      resetGlobalState();
+
+      /* Weights can be listed in a separate file (denoted by parameter "weights-file") or directly
+       * in the Joshua config file. Config file values take precedence.
+       */
+      this.readWeights(joshuaConfiguration.weights_file);
-      
-      
++
++
+      /* Add command-line-passed weights to the weights array for processing below */
+      if (!Strings.isNullOrEmpty(joshuaConfiguration.weight_overwrite)
+          && !joshuaConfiguration.weight_overwrite.trim().isEmpty()) {
+        // Trim first so leading whitespace does not yield an empty first token
+        String[] tokens = joshuaConfiguration.weight_overwrite.trim().split("\\s+");
+        // The overwrite string is a flat list of FEATURE VALUE pairs
+        if (tokens.length % 2 != 0) {
+          throw new IllegalArgumentException(String.format(
+              "weight_overwrite must be a list of 'FEATURE VALUE' pairs, got '%s'",
+              joshuaConfiguration.weight_overwrite));
+        }
+        for (int i = 0; i < tokens.length; i += 2) {
+          String feature = tokens[i];
+          // Parsed here to validate the value before it is queued
+          float value = Float.parseFloat(tokens[i+1]);
+
+          if (joshuaConfiguration.moses)
+            feature = demoses(feature);
+
+          joshuaConfiguration.weights.add(String.format("%s %s", feature, tokens[i+1]));
+          LOG.info("COMMAND LINE WEIGHT: {} -> {}", feature, value);
+        }
+      }
+
+      /* Read the weights found in the config file */
+      for (String pairStr: joshuaConfiguration.weights) {
+        String[] pair = pairStr.split("\\s+");
+
+        /* Sanity check for old-style unsupported feature invocations. */
+        if (pair.length != 2) {
+          String errMsg = "FATAL: Invalid feature weight line found in config file.\n" +
+              String.format("The line was '%s'\n", pairStr) +
+              "You might be using an old version of the config file that is no longer supported\n" +
+              "Check joshua.apache.org or email dev@joshua.apache.org for help\n" +
+              "Code = " + 17;
+          throw new RuntimeException(errMsg);
+        }
+
+        weights.add(hashFeature(pair[0]), Float.parseFloat(pair[1]));
+      }
+
+      LOG.info("Read {} weights", weights.size());
+
+      // Do this before loading the grammars and the LM.
+      this.featureFunctions.clear();
+
+      // Initialize and load grammars. This must happen first, since the vocab gets defined by
+      // the packed grammar (if any)
+      this.initializeTranslationGrammars();
+      LOG.info("Grammar loading took: {} seconds.",
+          (System.currentTimeMillis() - pre_load_time) / 1000);
+
+      // Initialize the features: requires that LM model has been initialized.
+      this.initializeFeatureFunctions();
+
+      // This is mostly for compatibility with the Moses tuning script
+      if (joshuaConfiguration.show_weights_and_quit) {
+        for (Entry<Integer, Float> entry : weights.entrySet()) {
+          System.out.println(String.format("%s=%.5f", FeatureMap.getFeature(entry.getKey()), entry.getValue()));
+        }
+        // TODO (fhieber): this functionality should not be in main Decoder class and simply exit.
+        System.exit(0);
+      }
+
+      // Sort the TM grammars (needed to do cube pruning)
+      if (joshuaConfiguration.amortized_sorting) {
+        LOG.info("Grammar sorting happening lazily on-demand.");
+      } else {
+        long pre_sort_time = System.currentTimeMillis();
+        for (Grammar grammar : this.grammars) {
+          grammar.sortGrammar(this.featureFunctions);
+        }
+        LOG.info("Grammar sorting took {} seconds.",
+            (System.currentTimeMillis() - pre_sort_time) / 1000);
+      }
+
+    } catch (IOException e) {
+      LOG.warn(e.getMessage(), e);
+    }
+  }
+
+  /**
+   * Initializes translation grammars Retained for backward compatibility
+   *
+   * @throws IOException Several grammar elements read from disk that can
+   * cause IOExceptions.
+   */
+  private void initializeTranslationGrammars() throws IOException {
+
+    if (joshuaConfiguration.tms.size() > 0) {
+
+      // collect packedGrammars to check if they use a shared vocabulary
+      final List<PackedGrammar> packed_grammars = new ArrayList<>();
+
+      // tm = {thrax/hiero,packed,samt,moses} OWNER LIMIT FILE
+      for (String tmLine : joshuaConfiguration.tms) {
+
+        String type = tmLine.substring(0,  tmLine.indexOf(' '));
+        String[] args = tmLine.substring(tmLine.indexOf(' ')).trim().split("\\s+");
+        HashMap<String, String> parsedArgs = FeatureFunction.parseArgs(args);
+
+        String owner = parsedArgs.get("owner");
+        int span_limit = Integer.parseInt(parsedArgs.get("maxspan"));
+        String path = parsedArgs.get("path");
+
+        Grammar grammar;
+        if (! type.equals("moses") && ! type.equals("phrase")) {
+          if (new File(path).isDirectory()) {
+            try {
+              PackedGrammar packed_grammar = new PackedGrammar(path, span_limit, owner, type, joshuaConfiguration);
+              packed_grammars.add(packed_grammar);
+              grammar = packed_grammar;
+            } catch (FileNotFoundException e) {
+              String msg = String.format("Couldn't load packed grammar from '%s'. ", path)
+                  + "Perhaps it doesn't exist, or it may be an old packed file format.";
+              // Keep the underlying exception so the missing file is visible in the stack trace
+              throw new RuntimeException(msg, e);
+            }
+          } else {
+            // thrax, hiero, samt
+            grammar = new MemoryBasedBatchGrammar(type, path, owner,
+                joshuaConfiguration.default_non_terminal, span_limit, joshuaConfiguration);
+          }
+
+        } else {
+
+          joshuaConfiguration.search_algorithm = "stack";
+          grammar = new PhraseTable(path, owner, type, joshuaConfiguration);
+        }
+
+        this.grammars.add(grammar);
+      }
+
+      checkSharedVocabularyChecksumsForPackedGrammars(packed_grammars);
+
+    } else {
+      LOG.warn("no grammars supplied!  Supplying dummy glue grammar.");
+      MemoryBasedBatchGrammar glueGrammar = new MemoryBasedBatchGrammar("glue", joshuaConfiguration, -1);
+      glueGrammar.addGlueRules(featureFunctions);
+      this.grammars.add(glueGrammar);
+    }
-    
++
+    /* Add the grammar for custom entries */
+    if (joshuaConfiguration.search_algorithm.equals("stack"))
+      this.customPhraseTable = new PhraseTable("custom", joshuaConfiguration);
+    else
+      this.customPhraseTable = new MemoryBasedBatchGrammar("custom", joshuaConfiguration, 20);
+    this.grammars.add(this.customPhraseTable);
-    
++
+    /* Create an epsilon-deleting grammar */
+    if (joshuaConfiguration.lattice_decoding) {
+      LOG.info("Creating an epsilon-deleting grammar");
+      MemoryBasedBatchGrammar latticeGrammar = new MemoryBasedBatchGrammar("lattice", joshuaConfiguration, -1);
+      HieroFormatReader reader = new HieroFormatReader(OwnerMap.register("lattice"));
+
+      String goalNT = FormatUtils.cleanNonTerminal(joshuaConfiguration.goal_symbol);
+      String defaultNT = FormatUtils.cleanNonTerminal(joshuaConfiguration.default_non_terminal);
+
+      //FIXME: arguments changed to match string format on best effort basis.  Author please review.
+      String ruleString = String.format("[%s] ||| [%s,1] <eps> ||| [%s,1] ||| ", goalNT, defaultNT, defaultNT);
+
+      Rule rule = reader.parseLine(ruleString);
+      latticeGrammar.addRule(rule);
+      rule.estimateRuleCost(featureFunctions);
+
+      this.grammars.add(latticeGrammar);
+    }
+
+    /* Now create a feature function for each owner */
+    final Set<OwnerId> ownersSeen = new HashSet<>();
+
+    for (Grammar grammar: this.grammars) {
+      OwnerId owner = grammar.getOwner();
+      if (! ownersSeen.contains(owner)) {
+        this.featureFunctions.add(
+            new PhraseModel(
+                weights, new String[] { "tm", "-owner", getOwner(owner) }, joshuaConfiguration, grammar));
+        ownersSeen.add(owner);
+      }
+    }
+
+    LOG.info("Memory used {} MB",
+        ((Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()) / 1000000.0));
+  }
+
+  /**
+   * Checks if multiple packedGrammars have the same vocabulary by comparing their vocabulary file checksums.
+   *
+   * @throws RuntimeException if two packed grammars report different vocabulary checksums
+   */
+  private static void checkSharedVocabularyChecksumsForPackedGrammars(final List<PackedGrammar> packed_grammars) {
+    String previous_checksum = "";
+    for (PackedGrammar grammar : packed_grammars) {
+      final String checksum = grammar.computeVocabularyChecksum();
+      if (previous_checksum.isEmpty()) {
+        previous_checksum = checksum;
+      } else {
+        if (!checksum.equals(previous_checksum)) {
+          throw new RuntimeException(
+              "Trying to load multiple packed grammars with different vocabularies! " +
+              "Have you packed them jointly?");
+        }
+        previous_checksum = checksum;
+      }
+    }
+  }
+
+  /*
+   * This function reads the weights for the model. Feature names and their weights are listed one
+   * per line in the following format:
-   * 
++   *
+   * FEATURE_NAME WEIGHT
+   *
+   * Lines that are empty, start with "#" or "//", or contain no space are skipped. A null or
+   * empty fileName means there is no weights file.
+   */
+  private void readWeights(String fileName) {
+    Decoder.weights = new FeatureVector(5);
+
+    if (Strings.isNullOrEmpty(fileName))
+      return;
+
-    try {
-      LineReader lineReader = new LineReader(fileName);
-
++    try (LineReader lineReader = new LineReader(fileName);) {
+      for (String line : lineReader) {
+        // Normalize separators to single spaces, then strip surrounding whitespace so that
+        // whitespace-only lines and indented comments are skipped instead of mis-parsed.
+        line = line.replaceAll(spaceSeparator, " ").trim();
+
+        if (line.equals("") || line.startsWith("#") || line.startsWith("//")
+            || line.indexOf(' ') == -1)
+          continue;
+
+        String[] tokens = line.split(spaceSeparator);
+        String feature = tokens[0];
+        float value = Float.parseFloat(tokens[1]);
+
+        // Kludge for compatibility with Moses tuners
+        if (joshuaConfiguration.moses) {
+          feature = demoses(feature);
+        }
+
+        weights.add(hashFeature(feature), value);
+      }
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
+    }
+    LOG.info("Read {} weights from file '{}'", weights.size(), fileName);
+  }
+
+  /**
+   * Maps a Moses-style feature name to Joshua's naming: if the name ends in '=', all '='
+   * characters are removed; "OOV_Penalty" becomes "OOVPenalty"; names starting with "tm-" or
+   * "lm-" have every '-' replaced by '_'.
+   */
+  private String demoses(String feature) {
+    if (feature.endsWith("="))
+      feature = feature.replace("=", "");
+    if (feature.equals("OOV_Penalty"))
+      feature = "OOVPenalty";
+    else if (feature.startsWith("tm-") || feature.startsWith("lm-"))
+      feature = feature.replace("-",  "_");
+    return feature;
+  }
+
+  /**
+   * Feature functions are instantiated with a line of the form
+   *
+   * <pre>
+   *   FEATURE OPTIONS
+   * </pre>
+   *
+   * Weights for features are listed separately.
+   *
+   */
+  private void initializeFeatureFunctions() {
+
+    for (String featureLine : joshuaConfiguration.features) {
+      // line starts with NAME, followed by args
+      // 1. create new class named NAME, pass it config, weights, and the args
+
+      String[] fields = featureLine.split("\\s+");
+      String featureName = fields[0];
-      
++
+      try {
-        
++
+        Class<?> clas = getFeatureFunctionClass(featureName);
+        Constructor<?> constructor = clas.getConstructor(FeatureVector.class,
+            String[].class, JoshuaConfiguration.class);
+        FeatureFunction feature = (FeatureFunction) constructor.newInstance(weights, fields, joshuaConfiguration);
+        this.featureFunctions.add(feature);
-        
++
+      } catch (Exception e) {
-        throw new RuntimeException(String.format("Unable to instantiate feature function '%s'!", featureLine), e); 
++        throw new RuntimeException(String.format("Unable to instantiate feature function '%s'!", featureLine), e);
+      }
+    }
+
+    for (FeatureFunction feature : featureFunctions) {
+      LOG.info("FEATURE: {}", feature.logString());
+    }
+  }
+
+  /**
+   * Searches a list of predefined paths for classes, and returns the first one found. Meant for
+   * instantiating feature functions. In each package the exact name is tried first, then the
+   * name with an "FF" suffix.
+   *
+   * @param featureName Class name of the feature to return.
+   * @return the class, found in one of the search paths
+   * @throws IllegalArgumentException if no matching class exists in any search path
+   */
+  private Class<?> getFeatureFunctionClass(String featureName) {
+    String[] packages = { "org.apache.joshua.decoder.ff", "org.apache.joshua.decoder.ff.lm", "org.apache.joshua.decoder.ff.phrase" };
+    for (String path : packages) {
+      for (String candidate : new String[] { featureName, featureName + "FF" }) {
+        try {
+          return Class.forName(String.format("%s.%s", path, candidate));
+        } catch (ClassNotFoundException e) {
+          // not in this package under this name; keep searching
+        }
+      }
+    }
+    throw new IllegalArgumentException(String.format(
+        "No feature function class found for '%s' (searched %s)",
+        featureName, String.join(", ", packages)));
+  }
-  
++
+  /**
-   * Adds a rule to the custom grammar.  
-   * 
++   * Adds a rule to the custom grammar.
++   *
+   * @param rule the rule to add
+   */
+  public void addCustomRule(Rule rule) {
+    customPhraseTable.addRule(rule);
+    rule.estimateRuleCost(featureFunctions);
+  }
+
+  public Grammar getCustomPhraseTable() {
+    return customPhraseTable;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
index f25590c,0000000..2ac5269
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
@@@ -1,148 -1,0 +1,147 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.net.InetSocketAddress;
+
- import com.sun.net.httpserver.HttpServer;
-
+import org.apache.joshua.decoder.JoshuaConfiguration.SERVER_TYPE;
+import org.apache.joshua.decoder.io.TranslationRequestStream;
++import org.apache.joshua.server.ServerThread;
+import org.apache.joshua.server.TcpServer;
+import org.apache.log4j.Level;
+import org.apache.log4j.LogManager;
- import org.apache.joshua.server.ServerThread;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
++import com.sun.net.httpserver.HttpServer;
++
+/**
+ * Implements decoder initialization, including interaction with <code>JoshuaConfiguration</code>
+ * and <code>DecoderTask</code>.
- *
++ *
+ * @author Zhifei Li, zhifei.work@gmail.com
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @author Lane Schwartz dowobeha@users.sourceforge.net
+ */
+public class JoshuaDecoder {
+
+  private static final Logger LOG = LoggerFactory.getLogger(JoshuaDecoder.class);
+
+  // ===============================================================
+  // Main
+  // ===============================================================
+  /**
+   * Command-line entry point: parses the arguments into a configuration, loads the decoder, then
+   * either starts a TCP/HTTP server or translates the input file (or STDIN), printing the
+   * translations to STDOUT.
+   *
+   * @param args command-line arguments, see {@link ArgsParser}
+   * @throws IOException if the input file, the n-best file or the HTTP server socket cannot be
+   *         opened
+   */
+  public static void main(String[] args) throws IOException {
+
+    // default log level
+    LogManager.getRootLogger().setLevel(Level.INFO);
+
+    JoshuaConfiguration joshuaConfiguration = new JoshuaConfiguration();
+    // Constructed only for its side effects: it populates joshuaConfiguration.
+    new ArgsParser(args, joshuaConfiguration);
+
+    long startTime = System.currentTimeMillis();
+
+    /* Step-0: some sanity checking */
+    joshuaConfiguration.sanityCheck();
+
+    /* Step-1: initialize the decoder, test-set independent */
+    Decoder decoder = new Decoder(joshuaConfiguration);
+
+    LOG.info("Model loading took {} seconds", (System.currentTimeMillis() - startTime) / 1000);
+    LOG.info("Memory used {} MB", ((Runtime.getRuntime().totalMemory()
+        - Runtime.getRuntime().freeMemory()) / 1000000.0));
+
+    /* Step-2: Decoding */
+    // create a server if requested, which will create TranslationRequest objects
+    if (joshuaConfiguration.server_port > 0) {
+      int port = joshuaConfiguration.server_port;
+      if (joshuaConfiguration.server_type == SERVER_TYPE.TCP) {
+        new TcpServer(decoder, port, joshuaConfiguration).start();
+
+      } else if (joshuaConfiguration.server_type == SERVER_TYPE.HTTP) {
+        joshuaConfiguration.use_structured_output = true;
-        
++
+        HttpServer server = HttpServer.create(new InetSocketAddress(port), 0);
+        LOG.info("HTTP Server running and listening on port {}.", port);
+        server.createContext("/", new ServerThread(null, decoder, joshuaConfiguration));
+        server.setExecutor(null); // creates a default executor
+        server.start();
+      } else {
+        LOG.error("Unknown server type");
+        System.exit(1);
+      }
+      return;
+    }
-    
++
+    // Create a TranslationRequest object, reading from a file if requested, or from STDIN
-    InputStream input = (joshuaConfiguration.input_file != null) 
++    final boolean readFromFile = joshuaConfiguration.input_file != null;
++    InputStream input = readFromFile
+        ? new FileInputStream(joshuaConfiguration.input_file)
+        : System.in;
+
+    try {
+      // NOTE(review): InputStreamReader uses the platform default charset here.
+      BufferedReader reader = new BufferedReader(new InputStreamReader(input));
+      TranslationRequestStream fileRequest = new TranslationRequestStream(reader, joshuaConfiguration);
+      TranslationResponseStream translationResponseStream = decoder.decodeAll(fileRequest);
-      
++
+      // Create the n-best output stream. try-with-resources permits a null resource (no n-best
+      // file requested), which is then simply not closed.
+      try (FileWriter nbest_out = (joshuaConfiguration.n_best_file != null)
+          ? new FileWriter(joshuaConfiguration.n_best_file)
+          : null) {
+        for (Translation translation: translationResponseStream) {
-
+          /*
+           * We need to munge the feature value outputs in order to be compatible with Moses tuners.
+           * Whereas Joshua writes to STDOUT whatever is specified in the `output-format` parameter,
+           * Moses expects the simple translation on STDOUT and the n-best list in a file with a fixed
+           * format.
+           */
+          if (joshuaConfiguration.moses) {
+            String text = translation.toString().replaceAll("=", "= ");
+            // Write the complete formatted string to the n-best file
+            if (nbest_out != null)
+              nbest_out.write(text);
+
+            // Extract just the translation (second " ||| " field of the first line) for STDOUT
+            int newline = text.indexOf('\n');
+            if (newline >= 0)
+              text = text.substring(0, newline);
+            String[] fields = text.split(" \\|\\|\\| ");
+            if (fields.length < 2) {
+              throw new IllegalStateException(
+                  String.format("Unexpected translation output format: '%s'", text));
+            }
+
+            System.out.println(fields[1]);
+
+          } else {
+            System.out.print(translation.toString());
+          }
+        }
+      }
+    } finally {
+      // Only close the stream we opened; leave STDIN alone.
+      if (readFromFile)
+        input.close();
+    }
+
+    LOG.info("Decoding completed.");
+    LOG.info("Memory used {} MB", ((Runtime.getRuntime().totalMemory()
+        - Runtime.getRuntime().freeMemory()) / 1000000.0));
+
+    /* Step-3: clean up */
+    decoder.cleanUp();
+    LOG.info("Total running time: {} seconds",  (System.currentTimeMillis() - startTime) / 1000);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
index 6453bd1,0000000..544e16b
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
+++ b/joshua-core/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
@@@ -1,116 -1,0 +1,114 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.decoder;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiFeatures;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
+import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiWordAlignmentList;
+import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
+
+import java.util.List;
+
+import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.hypergraph.HyperGraph;
+import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationState;
+import org.apache.joshua.decoder.segment_file.Sentence;
- import org.apache.joshua.decoder.segment_file.Token;
- import org.apache.joshua.util.FormatUtils;
+
+/**
+ * This factory provides methods to create StructuredTranslation objects
+ * from either Viterbi derivations or KBest derivations.
- *
++ *
+ * @author fhieber
+ */
+public class StructuredTranslationFactory {
-
++
+ /**
+ * Returns a StructuredTranslation instance from the Viterbi derivation.
- *
++ *
+ * @param sourceSentence the source sentence
+ * @param hypergraph the hypergraph object
+ * @param featureFunctions the list of active feature functions
+ * @return A StructuredTranslation object representing the Viterbi derivation.
+ */
+ public static StructuredTranslation fromViterbiDerivation(
+ final Sentence sourceSentence,
+ final HyperGraph hypergraph,
+ final List<FeatureFunction> featureFunctions) {
+ final long startTime = System.currentTimeMillis();
+ final String translationString = removeSentenceMarkers(getViterbiString(hypergraph));
+ return new StructuredTranslation(
+ sourceSentence,
+ translationString,
+ extractTranslationTokens(translationString),
+ extractTranslationScore(hypergraph),
+ getViterbiWordAlignmentList(hypergraph),
+ getViterbiFeatures(hypergraph, featureFunctions, sourceSentence).toStringMap(),
+ (System.currentTimeMillis() - startTime) / 1000.0f);
+ }
-
++
+ /**
+ * Returns a StructuredTranslation from an empty decoder output
+ * @param sourceSentence the source sentence
+ * @return a StructuredTranslation object
+ */
+ public static StructuredTranslation fromEmptyOutput(final Sentence sourceSentence) {
+ return new StructuredTranslation(
+ sourceSentence, "", emptyList(), 0, emptyList(), emptyMap(), 0f);
+ }
-
++
+ /**
- * Returns a StructuredTranslation instance from a KBest DerivationState.
++ * Returns a StructuredTranslation instance from a KBest DerivationState.
+ * @param sourceSentence Sentence object representing the source.
+ * @param derivationState the KBest DerivationState.
+ * @return A StructuredTranslation object representing the derivation encoded by derivationState.
+ */
+ public static StructuredTranslation fromKBestDerivation(
+ final Sentence sourceSentence,
+ final DerivationState derivationState) {
+ final long startTime = System.currentTimeMillis();
+ final String translationString = removeSentenceMarkers(derivationState.getHypothesis());
+ return new StructuredTranslation(
+ sourceSentence,
+ translationString,
+ extractTranslationTokens(translationString),
+ derivationState.getModelCost(),
+ derivationState.getWordAlignmentList(),
+ derivationState.getFeatures().toStringMap(),
+ (System.currentTimeMillis() - startTime) / 1000.0f);
+ }
-
++
+ private static float extractTranslationScore(final HyperGraph hypergraph) {
+ if (hypergraph == null) {
+ return 0;
+ } else {
+ return hypergraph.goalNode.getScore();
+ }
+ }
-
++
+ private static List<String> extractTranslationTokens(final String translationString) {
+ if (translationString.isEmpty()) {
+ return emptyList();
+ } else {
+ return asList(translationString.split("\\s+"));
+ }
+ }
+}
[04/17] incubator-joshua git commit: Fix a number of issues: - Reader
now implements autocloseable - Close various leaks from LineReader -
LineReader no longer implements custom finalize(). Resources should be
explicitly closed when no longer needed. T
Posted by mj...@apache.org.
Fix a number of issues:
- Reader now implements AutoCloseable
- Close various leaks from LineReader
- LineReader no longer implements a custom finalize(). Resources should be
  explicitly closed when no longer needed; compiler resource-leak warnings
  help catch places where this is missed.
- Start refactoring copy-pasted line-counting code into a new type,
  ExistingUTF8EncodedTextFile. Text files are used so widely in the code
  base that they deserve their own type.
- Fix warnings about unused fields, methods, and local variables
- Delete some old, unused legacy classes
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/840eb4ce
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/840eb4ce
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/840eb4ce
Branch: refs/heads/7
Commit: 840eb4ce4854d3f5e4f792ce4d909d6276bfcb08
Parents: 9385cf7
Author: max thomas <ma...@maxthomas.io>
Authored: Thu Aug 18 15:04:52 2016 -0400
Committer: max thomas <ma...@maxthomas.io>
Committed: Tue Aug 30 11:40:22 2016 -0400
----------------------------------------------------------------------
.../org/apache/joshua/adagrad/AdaGradCore.java | 218 +------
.../org/apache/joshua/adagrad/Optimizer.java | 52 +-
.../joshua/corpus/syntax/ArraySyntaxTree.java | 20 +-
.../org/apache/joshua/decoder/ArgsParser.java | 24 +-
.../java/org/apache/joshua/decoder/Decoder.java | 41 +-
.../apache/joshua/decoder/JoshuaDecoder.java | 17 +-
.../decoder/StructuredTranslationFactory.java | 18 +-
.../org/apache/joshua/decoder/Translation.java | 35 +-
.../joshua/decoder/chart_parser/DotChart.java | 78 +--
.../apache/joshua/decoder/ff/TargetBigram.java | 5 +-
.../joshua/decoder/ff/fragmentlm/Tree.java | 104 ++--
.../org/apache/joshua/decoder/ff/lm/KenLM.java | 20 +-
.../joshua/decoder/ff/lm/buildin_lm/TrieLM.java | 168 ++----
.../joshua/decoder/ff/tm/CreateGlueGrammar.java | 42 +-
.../decoder/ff/tm/packed/PackedGrammar.java | 84 ++-
.../java/org/apache/joshua/lattice/Lattice.java | 54 +-
.../java/org/apache/joshua/metrics/CHRF.java | 97 ++--
.../java/org/apache/joshua/metrics/SARI.java | 117 ++--
.../java/org/apache/joshua/mira/MIRACore.java | 231 +-------
.../java/org/apache/joshua/mira/Optimizer.java | 26 +-
.../java/org/apache/joshua/pro/PROCore.java | 224 +-------
.../org/apache/joshua/tools/GrammarPacker.java | 87 +--
.../org/apache/joshua/tools/LabelPhrases.java | 84 +--
.../org/apache/joshua/tools/TestSetFilter.java | 83 +--
.../java/org/apache/joshua/util/BotMap.java | 94 ---
.../java/org/apache/joshua/util/Constants.java | 6 +-
.../org/apache/joshua/util/FileUtility.java | 261 +--------
.../org/apache/joshua/util/IntegerPair.java | 36 --
.../java/org/apache/joshua/util/ListUtil.java | 51 --
src/main/java/org/apache/joshua/util/Lists.java | 567 -------------------
.../org/apache/joshua/util/NullIterator.java | 65 ---
.../org/apache/joshua/util/QuietFormatter.java | 36 --
.../org/apache/joshua/util/ReverseOrder.java | 39 --
.../org/apache/joshua/util/SampledList.java | 69 ---
.../org/apache/joshua/util/SocketUtility.java | 144 -----
.../apache/joshua/util/encoding/Analyzer.java | 85 +--
.../util/encoding/FeatureTypeAnalyzer.java | 43 +-
.../util/io/ExistingUTF8EncodedTextFile.java | 77 +++
.../apache/joshua/util/io/IndexedReader.java | 13 +-
.../org/apache/joshua/util/io/LineReader.java | 123 ++--
.../org/apache/joshua/util/io/NullReader.java | 63 ---
.../java/org/apache/joshua/util/io/Reader.java | 11 +-
.../joshua/util/quantization/Quantizer.java | 48 +-
.../quantization/QuantizerConfiguration.java | 167 +++---
.../util/quantization/StatelessQuantizer.java | 26 +-
.../java/org/apache/joshua/zmert/MertCore.java | 205 +------
.../org/apache/joshua/packed/Benchmark.java | 28 +-
.../system/MultithreadedTranslationTests.java | 12 +-
48 files changed, 947 insertions(+), 3251 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java b/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
index 9dc81a4..53bb342 100755
--- a/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
+++ b/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
@@ -50,7 +50,7 @@ import org.apache.joshua.decoder.Decoder;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.metrics.EvaluationMetric;
import org.apache.joshua.util.StreamGobbler;
-
+import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -67,17 +67,13 @@ public class AdaGradCore {
private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
private final JoshuaConfiguration joshuaConfiguration;
- private final Runtime myRuntime = Runtime.getRuntime();
private TreeSet<Integer>[] indicesOfInterest_all;
- private int progress;
-
private int verbosity; // anything of priority <= verbosity will be printed
// (lower value for priority means more important)
private Random randGen;
- private int generatedRands;
private int numSentences;
// number of sentences in the dev set
@@ -235,7 +231,6 @@ public class AdaGradCore {
private boolean usePseudoBleu = true; // need to use pseudo corpus to compute bleu?
private boolean returnBest = true; // return the best weight during tuning
private boolean needScale = true; // need scaling?
- private String trainingMode;
private int oraSelectMode = 1;
private int predSelectMode = 1;
private int adagradIter = 1;
@@ -265,28 +260,27 @@ public class AdaGradCore {
this.joshuaConfiguration = joshuaConfiguration;
}
- public AdaGradCore(String[] args, JoshuaConfiguration joshuaConfiguration) {
+ public AdaGradCore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(args);
initialize(0);
}
- public AdaGradCore(String configFileName, JoshuaConfiguration joshuaConfiguration) {
+ public AdaGradCore(String configFileName, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
this.joshuaConfiguration = joshuaConfiguration;
EvaluationMetric.set_knownMetrics();
processArgsArray(cfgFileToArgsArray(configFileName));
initialize(0);
}
- private void initialize(int randsToSkip) {
+ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
randGen = new Random(seed);
for (int r = 1; r <= randsToSkip; ++r) {
randGen.nextDouble();
}
- generatedRands = randsToSkip;
if (randsToSkip == 0) {
println("----------------------------------------------------", 1);
@@ -300,7 +294,7 @@ public class AdaGradCore {
// count the total num of sentences to be decoded, reffilename is the combined reference file
// name(auto generated)
- numSentences = countLines(refFileName) / refsPerSen;
+ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines() / refsPerSen;
// ??
processDocInfo();
@@ -313,7 +307,7 @@ public class AdaGradCore {
set_docSubsetInfo(docSubsetInfo);
// count the number of initial features
- numParams = countNonEmptyLines(paramsFileName) - 1;
+ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
numParamsOld = numParams;
// read parameter config file
@@ -864,7 +858,6 @@ public class AdaGradCore {
// iterations if the user specifies a value for prevMERTIterations
// that causes MERT to skip candidates from early iterations.
- double[] currFeatVal = new double[1 + numParams];
String[] featVal_str;
int totalCandidateCount = 0;
@@ -914,13 +907,6 @@ public class AdaGradCore {
// extract feature value
featVal_str = feats_str.split("\\s+");
- if (feats_str.indexOf('=') != -1) {
- for (String featurePair : featVal_str) {
- String[] pair = featurePair.split("=");
- String name = pair[0];
- Double value = Double.parseDouble(pair[1]);
- }
- }
existingCandStats.put(sents_str, stats_str);
candCount[i] += 1;
newCandidatesAdded[it] += 1;
@@ -1114,7 +1100,6 @@ public class AdaGradCore {
for (String featurePair : featVal_str) {
String[] pair = featurePair.split("=");
String name = pair[0];
- Double value = Double.parseDouble(pair[1]);
int featId = Vocabulary.id(name);
// need to identify newly fired feats here
@@ -1610,8 +1595,6 @@ public class AdaGradCore {
BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
PrintWriter outFile = new PrintWriter(cfgFileName);
- BufferedReader inFeatDefFile = null;
- PrintWriter outFeatDefFile = null;
int origFeatNum = 0; // feat num in the template file
String line = inFile.readLine();
@@ -1813,7 +1796,7 @@ public class AdaGradCore {
// belongs to,
// and its order in that document. (can also use '-' instead of '_')
- int docInfoSize = countNonEmptyLines(docInfoFileName);
+ int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
if (docInfoSize < numSentences) { // format #1 or #2
numDocuments = docInfoSize;
@@ -2736,10 +2719,10 @@ public class AdaGradCore {
} else {
nextIndex = 1;
}
- int lineCount = countLines(prefix + nextIndex);
+ int lineCount = new ExistingUTF8EncodedTextFile(prefix + nextIndex).getNumberOfLines();
for (int r = 0; r < numFiles; ++r) {
- if (countLines(prefix + nextIndex) != lineCount) {
+ if (new ExistingUTF8EncodedTextFile(prefix + nextIndex).getNumberOfLines() != lineCount) {
String msg = "Line count mismatch in " + (prefix + nextIndex) + ".";
throw new RuntimeException(msg);
}
@@ -2901,109 +2884,11 @@ public class AdaGradCore {
return str;
}
- private int countLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null)
- ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
- private int countNonEmptyLines(String fileName) {
- int count = 0;
-
- try {
- BufferedReader inFile = new BufferedReader(new FileReader(fileName));
-
- String line;
- do {
- line = inFile.readLine();
- if (line != null && line.length() > 0)
- ++count;
- } while (line != null);
-
- inFile.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- return count;
- }
-
private String fullPath(String dir, String fileName) {
File dummyFile = new File(dir, fileName);
return dummyFile.getAbsolutePath();
}
- @SuppressWarnings("unused")
- private void cleanupMemory() {
- cleanupMemory(100, false);
- }
-
- @SuppressWarnings("unused")
- private void cleanupMemorySilently() {
- cleanupMemory(100, true);
- }
-
- @SuppressWarnings("static-access")
- private void cleanupMemory(int reps, boolean silent) {
- int bytesPerMB = 1024 * 1024;
-
- long totalMemBefore = myRuntime.totalMemory();
- long freeMemBefore = myRuntime.freeMemory();
- long usedMemBefore = totalMemBefore - freeMemBefore;
-
- long usedCurr = usedMemBefore;
- long usedPrev = usedCurr;
-
- // perform garbage collection repeatedly, until there is no decrease in
- // the amount of used memory
- for (int i = 1; i <= reps; ++i) {
- myRuntime.runFinalization();
- myRuntime.gc();
- (Thread.currentThread()).yield();
-
- usedPrev = usedCurr;
- usedCurr = myRuntime.totalMemory() - myRuntime.freeMemory();
-
- if (usedCurr == usedPrev)
- break;
- }
-
- if (!silent) {
- long totalMemAfter = myRuntime.totalMemory();
- long freeMemAfter = myRuntime.freeMemory();
- long usedMemAfter = totalMemAfter - freeMemAfter;
-
- println("GC: d_used = " + ((usedMemAfter - usedMemBefore) / bytesPerMB) + " MB "
- + "(d_tot = " + ((totalMemAfter - totalMemBefore) / bytesPerMB) + " MB).", 2);
- }
- }
-
- @SuppressWarnings("unused")
- private void printMemoryUsage() {
- int bytesPerMB = 1024 * 1024;
- long totalMem = myRuntime.totalMemory();
- long freeMem = myRuntime.freeMemory();
- long usedMem = totalMem - freeMem;
-
- println("Allocated memory: " + (totalMem / bytesPerMB) + " MB " + "(of which "
- + (usedMem / bytesPerMB) + " MB is being used).", 2);
- }
-
private void println(Object obj, int priority) {
if (priority <= verbosity)
println(obj);
@@ -3022,20 +2907,12 @@ public class AdaGradCore {
System.out.print(obj);
}
- @SuppressWarnings("unused")
- private void showProgress() {
- ++progress;
- if (progress % 100000 == 0)
- print(".", 2);
- }
-
private ArrayList<Double> randomLambda() {
ArrayList<Double> retLambda = new ArrayList<>(1 + numParams);
for (int c = 1; c <= numParams; ++c) {
if (isOptimizable[c]) {
double randVal = randGen.nextDouble(); // number in [0.0,1.0]
- ++generatedRands;
randVal = randVal * (maxRandValue[c] - minRandValue[c]); // number in [0.0,max-min]
randVal = minRandValue[c] + randVal; // number in [min,max]
retLambda.set(c, randVal);
@@ -3046,81 +2923,4 @@ public class AdaGradCore {
return retLambda;
}
-
- private double[] randomPerturbation(double[] origLambda, int i, double method, double param,
- double mult) {
- double sigma = 0.0;
- if (method == 1) {
- sigma = 1.0 / Math.pow(i, param);
- } else if (method == 2) {
- sigma = Math.exp(-param * i);
- } else if (method == 3) {
- sigma = Math.max(0.0, 1.0 - (i / param));
- }
-
- sigma = mult * sigma;
-
- double[] retLambda = new double[1 + numParams];
-
- for (int c = 1; c <= numParams; ++c) {
- if (isOptimizable[c]) {
- double randVal = 2 * randGen.nextDouble() - 1.0; // number in [-1.0,1.0]
- ++generatedRands;
- randVal = randVal * sigma; // number in [-sigma,sigma]
- randVal = randVal * origLambda[c]; // number in [-sigma*orig[c],sigma*orig[c]]
- randVal = randVal + origLambda[c]; // number in
- // [orig[c]-sigma*orig[c],orig[c]+sigma*orig[c]]
- // = [orig[c]*(1-sigma),orig[c]*(1+sigma)]
- retLambda[c] = randVal;
- } else {
- retLambda[c] = origLambda[c];
- }
- }
-
- return retLambda;
- }
-
- @SuppressWarnings("unused")
- private HashSet<Integer> indicesToDiscard(double[] slope, double[] offset) {
- // some lines can be eliminated: the ones that have a lower offset
- // than some other line with the same slope.
- // That is, for any k1 and k2:
- // if slope[k1] = slope[k2] and offset[k1] > offset[k2],
- // then k2 can be eliminated.
- // (This is actually important to do as it eliminates a bug.)
- // print("discarding: ",4);
-
- int numCandidates = slope.length;
- HashSet<Integer> discardedIndices = new HashSet<>();
- HashMap<Double, Integer> indicesOfSlopes = new HashMap<>();
- // maps slope to index of best candidate that has that slope.
- // ("best" as in the one with the highest offset)
-
- for (int k1 = 0; k1 < numCandidates; ++k1) {
- double currSlope = slope[k1];
- if (!indicesOfSlopes.containsKey(currSlope)) {
- indicesOfSlopes.put(currSlope, k1);
- } else {
- int existingIndex = indicesOfSlopes.get(currSlope);
- if (offset[existingIndex] > offset[k1]) {
- discardedIndices.add(k1);
- // print(k1 + " ",4);
- } else if (offset[k1] > offset[existingIndex]) {
- indicesOfSlopes.put(currSlope, k1);
- discardedIndices.add(existingIndex);
- // print(existingIndex + " ",4);
- }
- }
- }
-
- // old way of doing it; takes quadratic time (vs. linear time above)
- /*
- * for (int k1 = 0; k1 < numCandidates; ++k1) { for (int k2 = 0; k2 < numCandidates; ++k2) { if
- * (k1 != k2 && slope[k1] == slope[k2] && offset[k1] > offset[k2]) { discardedIndices.add(k2);
- * // print(k2 + " ",4); } } }
- */
-
- // println("",4);
- return discardedIndices;
- } // indicesToDiscard(double[] slope, double[] offset)
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/adagrad/Optimizer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/adagrad/Optimizer.java b/src/main/java/org/apache/joshua/adagrad/Optimizer.java
index 16c25cd..6ad85a8 100755
--- a/src/main/java/org/apache/joshua/adagrad/Optimizer.java
+++ b/src/main/java/org/apache/joshua/adagrad/Optimizer.java
@@ -18,14 +18,13 @@
*/
package org.apache.joshua.adagrad;
-import java.util.Collections;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Vector;
-import java.lang.Math;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.metrics.EvaluationMetric;
@@ -37,7 +36,7 @@ public class Optimizer {
output = _output; // (not used for now)
isOptimizable = _isOptimizable;
initialLambda = _initialLambda; // initial weights array
- paramDim = initialLambda.length - 1;
+ paramDim = initialLambda.length - 1;
initialLambda = _initialLambda;
feat_hash = _feat_hash; // feature hash table
stats_hash = _stats_hash; // suff. stats hash table
@@ -57,7 +56,7 @@ public class Optimizer {
System.arraycopy(finalLambda, 1, initialLambda, 1, paramDim);
if(needShuffle)
Collections.shuffle(sents);
-
+
double oraMetric, oraScore, predMetric, predScore;
double[] oraPredScore = new double[4];
double loss = 0;
@@ -70,10 +69,9 @@ public class Optimizer {
String[] vecOraFeat;
String[] vecPredFeat;
String[] featInfo;
- int thisBatchSize = 0;
int numBatch = 0;
int numUpdate = 0;
- Iterator it;
+ Iterator<Integer> it;
Integer diffFeatId;
//update weights
@@ -91,14 +89,13 @@ public class Optimizer {
HashMap<Integer, Double> H = new HashMap<>();
while( sentCount < sentNum ) {
loss = 0;
- thisBatchSize = batchSize;
++numBatch;
HashMap<Integer, Double> featDiff = new HashMap<>();
for(int b = 0; b < batchSize; ++b ) {
//find out oracle and prediction
s = sents.get(sentCount);
findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
-
+
//the model scores here are already scaled in findOraPred
oraMetric = oraPredScore[0];
oraScore = oraPredScore[1];
@@ -106,18 +103,18 @@ public class Optimizer {
predScore = oraPredScore[3];
oraFeat = oraPredFeat[0];
predFeat = oraPredFeat[1];
-
+
//update the scale
if(needScale) { //otherwise featscale remains 1.0
sumMetricScore += Math.abs(oraMetric + predMetric);
//restore the original model score
sumModelScore += Math.abs(oraScore + predScore) / featScale;
-
+
if(sumModelScore/sumMetricScore > scoreRatio)
featScale = sumMetricScore/sumModelScore;
}
// processedSent++;
-
+
vecOraFeat = oraFeat.split("\\s+");
vecPredFeat = predFeat.split("\\s+");
@@ -169,13 +166,12 @@ public class Optimizer {
//remember the model scores here are already scaled
double singleLoss = evalMetric.getToBeMinimized() ?
- (predMetric-oraMetric) - (oraScore-predScore)/featScale:
+ (predMetric-oraMetric) - (oraScore-predScore)/featScale:
(oraMetric-predMetric) - (oraScore-predScore)/featScale;
if(singleLoss > 0)
loss += singleLoss;
++sentCount;
if( sentCount >= sentNum ) {
- thisBatchSize = b + 1;
break;
}
} //for(int b : batchSize)
@@ -189,7 +185,7 @@ public class Optimizer {
Set<Integer> diffFeatSet = featDiff.keySet();
it = diffFeatSet.iterator();
while(it.hasNext()) { //note these are all non-zero gradients!
- diffFeatId = (Integer)it.next();
+ diffFeatId = it.next();
diffFeatVal = -1.0 * featDiff.get(diffFeatId); //gradient
if( regularization > 0 ) {
lastUpdateTime =
@@ -297,7 +293,7 @@ public class Optimizer {
finalLambda[i] =
Math.signum(oldVal) * clip( Math.abs(oldVal) - lam * eta * (numUpdate - lastUpdate.get(i)) / Hii );
else if( regularization == 2 ) {
- finalLambda[i] =
+ finalLambda[i] =
Math.pow( Hii/(lam+Hii), (numUpdate - lastUpdate.get(i)) ) * oldVal;
if(needAvg) { //fill the gap due to lazy update
double prevLambdaCopy = finalLambda[i];
@@ -338,7 +334,7 @@ public class Optimizer {
// numParamToPrint = paramDim > 10 ? 10 : paramDim; // how many parameters
// // to print
// result = paramDim > 10 ? "Final lambda (first 10): {" : "Final lambda: {";
-
+
// for (int i = 1; i <= numParamToPrint; ++i)
// result += String.format("%.4f", finalLambda[i]) + " ";
@@ -412,7 +408,7 @@ public class Optimizer {
return evalMetric.score(corpusStatsVal);
}
-
+
private void findOraPred(int sentId, double[] oraPredScore, String[] oraPredFeat, double[] lambda, double featScale)
{
double oraMetric=0, oraScore=0, predMetric=0, predScore=0;
@@ -424,11 +420,11 @@ public class Optimizer {
String oraCand = ""; //only used when BLEU/TER-BLEU is used as metric
String[] featStr;
String[] featInfo;
-
+
int actualFeatId;
double bestOraScore;
double worstPredScore;
-
+
if(oraSelectMode==1)
bestOraScore = NegInf; //larger score will be selected
else {
@@ -437,7 +433,7 @@ public class Optimizer {
else
bestOraScore = NegInf;
}
-
+
if(predSelectMode==1 || predSelectMode==2)
worstPredScore = NegInf; //larger score will be selected
else {
@@ -548,14 +544,14 @@ public class Optimizer {
}
}
}
-
+
oraPredScore[0] = oraMetric;
oraPredScore[1] = oraScore;
oraPredScore[2] = predMetric;
oraPredScore[3] = predScore;
oraPredFeat[0] = oraFeat;
oraPredFeat[1] = predFeat;
-
+
//update the BLEU metric statistics if pseudo corpus is used to compute BLEU/TER-BLEU
if(evalMetric.get_metricName().equals("BLEU") && usePseudoBleu ) {
String statString;
@@ -566,7 +562,7 @@ public class Optimizer {
for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j]);
}
-
+
if(evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu ) {
String statString;
String[] statVal_str;
@@ -577,7 +573,7 @@ public class Optimizer {
bleuHistory[sentId][j] = R*bleuHistory[sentId][j]+Integer.parseInt(statVal_str[j+2]); //the first 2 stats are TER stats
}
}
-
+
// compute *sentence-level* metric score for cand
private double computeSentMetric(int sentId, String cand) {
String statString;
@@ -667,7 +663,7 @@ public class Optimizer {
{
return featScale;
}
-
+
public static void initBleuHistory(int sentNum, int statCount)
{
bleuHistory = new double[sentNum][statCount];
@@ -682,7 +678,7 @@ public class Optimizer {
{
return finalMetricScore;
}
-
+
private final Vector<String> output;
private double[] initialLambda;
private final double[] finalLambda;
@@ -706,11 +702,11 @@ public class Optimizer {
//updates in each epoch if necessary
public static double eta;
public static double lam;
- public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
+ public static double R; //corpus decay(used only when pseudo corpus is used to compute BLEU)
public static EvaluationMetric evalMetric;
public static double[] normalizationOptions;
public static double[][] bleuHistory;
-
+
private final static double NegInf = (-1.0 / 0.0);
private final static double PosInf = (+1.0 / 0.0);
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java b/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
index 10efdc6..27303ec 100644
--- a/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
+++ b/src/main/java/org/apache/joshua/corpus/syntax/ArraySyntaxTree.java
@@ -71,6 +71,7 @@ public class ArraySyntaxTree implements SyntaxTree, Externalizable {
* Returns a collection of single-non-terminal labels that exactly cover the specified span in the
* lattice.
*/
+ @Override
public Collection<Integer> getConstituentLabels(int from, int to) {
Collection<Integer> labels = new HashSet<>();
int span_length = to - from;
@@ -167,6 +168,7 @@ public class ArraySyntaxTree implements SyntaxTree, Externalizable {
* in the lattice. The number of non-terminals concatenated is limited by MAX_CONCATENATIONS and
* the total number of labels returned is bounded by MAX_LABELS.
*/
+ @Override
public Collection<Integer> getConcatenatedLabels(int from, int to) {
Collection<Integer> labels = new HashSet<>();
@@ -216,6 +218,7 @@ public class ArraySyntaxTree implements SyntaxTree, Externalizable {
}
// TODO: can pre-comupute all that in top-down fashion.
+ @Override
public Collection<Integer> getCcgLabels(int from, int to) {
Collection<Integer> labels = new HashSet<>();
@@ -296,10 +299,12 @@ public class ArraySyntaxTree implements SyntaxTree, Externalizable {
return span;
}
+ @Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
// TODO Auto-generated method stub
}
+ @Override
public void writeExternal(ObjectOutput out) throws IOException {
// TODO Auto-generated method stub
}
@@ -310,18 +315,15 @@ public class ArraySyntaxTree implements SyntaxTree, Externalizable {
* @throws IOException if the file does not exist
*/
public void readExternalText(String file_name) throws IOException {
- LineReader reader = new LineReader(file_name);
- initialize();
- for (String line : reader) {
- if (line.trim().equals("")) continue;
- appendFromPennFormat(line);
+ try (LineReader reader = new LineReader(file_name);) {
+ initialize();
+ for (String line : reader) {
+ if (line.trim().equals("")) continue;
+ appendFromPennFormat(line);
+ }
}
}
- public void writeExternalText(String file_name) throws IOException {
- // TODO Auto-generated method stub
- }
-
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ArgsParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ArgsParser.java b/src/main/java/org/apache/joshua/decoder/ArgsParser.java
index 26ed674..97baa27 100644
--- a/src/main/java/org/apache/joshua/decoder/ArgsParser.java
+++ b/src/main/java/org/apache/joshua/decoder/ArgsParser.java
@@ -29,7 +29,7 @@ import org.slf4j.LoggerFactory;
/**
* @author orluke
- *
+ *
*/
public class ArgsParser {
@@ -40,7 +40,7 @@ public class ArgsParser {
/**
* Parse the arguments passed from the command line when the JoshuaDecoder application was
* executed from the command line.
- *
+ *
* @param args string array of input arguments
* @param config the {@link org.apache.joshua.decoder.JoshuaConfiguration}
* @throws IOException if there is an error wit the input arguments
@@ -49,8 +49,8 @@ public class ArgsParser {
/*
* Look for a verbose flag, -v.
- *
- * Look for an argument to the "-config" flag to find the config file, if any.
+ *
+ * Look for an argument to the "-config" flag to find the config file, if any.
*/
if (args.length >= 1) {
// Search for a verbose flag
@@ -59,15 +59,15 @@ public class ArgsParser {
Decoder.VERBOSE = Integer.parseInt(args[i + 1].trim());
config.setVerbosity(Decoder.VERBOSE);
}
-
- if (args[i].equals("-version")) {
- LineReader reader = new LineReader(String.format("%s/VERSION", System.getenv("JOSHUA")));
- reader.readLine();
- String version = reader.readLine().split("\\s+")[2];
- System.out.println(String.format("The Apache Joshua machine translator, version %s", version));
- System.out.println("joshua.incubator.apache.org");
- System.exit(0);
+ if (args[i].equals("-version")) {
+ try (LineReader reader = new LineReader(String.format("%s/VERSION", System.getenv("JOSHUA")));) {
+ reader.readLine();
+ String version = reader.readLine().split("\\s+")[2];
+ System.out.println(String.format("The Apache Joshua machine translator, version %s", version));
+ System.out.println("joshua.incubator.apache.org");
+ System.exit(0);
+ }
} else if (args[i].equals("-license")) {
try {
Files.readAllLines(Paths.get(String.format("%s/../LICENSE",
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/Decoder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/Decoder.java b/src/main/java/org/apache/joshua/decoder/Decoder.java
index 682e290..c15898c 100644
--- a/src/main/java/org/apache/joshua/decoder/Decoder.java
+++ b/src/main/java/org/apache/joshua/decoder/Decoder.java
@@ -23,8 +23,8 @@ import static org.apache.joshua.decoder.ff.tm.OwnerMap.getOwner;
import java.io.BufferedWriter;
import java.io.File;
-import java.io.IOException;
import java.io.FileNotFoundException;
+import java.io.IOException;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.HashMap;
@@ -36,11 +36,9 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
-import com.google.common.base.Strings;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.joshua.corpus.Vocabulary;
-import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.ff.FeatureFunction;
+import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.ff.PhraseModel;
import org.apache.joshua.decoder.ff.StatefulFF;
import org.apache.joshua.decoder.ff.lm.LanguageModelFF;
@@ -61,6 +59,9 @@ import org.apache.joshua.util.io.LineReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import com.google.common.base.Strings;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+
/**
* This class handles decoder initialization and the complication introduced by multithreading.
*
@@ -149,7 +150,7 @@ public class Decoder {
this.joshuaConfiguration = joshuaConfiguration;
this.grammars = new ArrayList<>();
this.customPhraseTable = null;
-
+
resetGlobalState();
}
@@ -313,7 +314,7 @@ public class Decoder {
// ===============================================================
/**
- * Moses requires the pattern .*_.* for sparse features, and prohibits underscores in dense features.
+ * Moses requires the pattern .*_.* for sparse features, and prohibits underscores in dense features.
* This conforms to that pattern. We assume non-conforming dense features start with tm_ or lm_,
* and the only sparse feature that needs converting is OOVPenalty.
*
@@ -344,8 +345,8 @@ public class Decoder {
* in the Joshua config file. Config file values take precedent.
*/
this.readWeights(joshuaConfiguration.weights_file);
-
-
+
+
/* Add command-line-passed weights to the weights array for processing below */
if (!Strings.isNullOrEmpty(joshuaConfiguration.weight_overwrite)) {
String[] tokens = joshuaConfiguration.weight_overwrite.split("\\s+");
@@ -485,14 +486,14 @@ public class Decoder {
glueGrammar.addGlueRules(featureFunctions);
this.grammars.add(glueGrammar);
}
-
+
/* Add the grammar for custom entries */
if (joshuaConfiguration.search_algorithm.equals("stack"))
this.customPhraseTable = new PhraseTable(null, "custom", "phrase", joshuaConfiguration);
else
this.customPhraseTable = new MemoryBasedBatchGrammar("custom", joshuaConfiguration, 20);
this.grammars.add(this.customPhraseTable);
-
+
/* Create an epsilon-deleting grammar */
if (joshuaConfiguration.lattice_decoding) {
LOG.info("Creating an epsilon-deleting grammar");
@@ -553,7 +554,7 @@ public class Decoder {
/*
* This function reads the weights for the model. Feature names and their weights are listed one
* per line in the following format:
- *
+ *
* FEATURE_NAME WEIGHT
*/
private void readWeights(String fileName) {
@@ -562,9 +563,7 @@ public class Decoder {
if (fileName.equals(""))
return;
- try {
- LineReader lineReader = new LineReader(fileName);
-
+ try (LineReader lineReader = new LineReader(fileName);) {
for (String line : lineReader) {
line = line.replaceAll("\\s+", " ");
@@ -619,17 +618,17 @@ public class Decoder {
String fields[] = featureLine.split("\\s+");
String featureName = fields[0];
-
+
try {
-
+
Class<?> clas = getFeatureFunctionClass(featureName);
Constructor<?> constructor = clas.getConstructor(FeatureVector.class,
String[].class, JoshuaConfiguration.class);
FeatureFunction feature = (FeatureFunction) constructor.newInstance(weights, fields, joshuaConfiguration);
this.featureFunctions.add(feature);
-
+
} catch (Exception e) {
- throw new RuntimeException(String.format("Unable to instantiate feature function '%s'!", featureLine), e);
+ throw new RuntimeException(String.format("Unable to instantiate feature function '%s'!", featureLine), e);
}
}
@@ -667,10 +666,10 @@ public class Decoder {
}
return clas;
}
-
+
/**
- * Adds a rule to the custom grammar.
- *
+ * Adds a rule to the custom grammar.
+ *
* @param rule the rule to add
*/
public void addCustomRule(Rule rule) {
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java b/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
index d10de8c..de5ab36 100644
--- a/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
+++ b/src/main/java/org/apache/joshua/decoder/JoshuaDecoder.java
@@ -26,21 +26,21 @@ import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetSocketAddress;
-import com.sun.net.httpserver.HttpServer;
-
import org.apache.joshua.decoder.JoshuaConfiguration.SERVER_TYPE;
import org.apache.joshua.decoder.io.TranslationRequestStream;
+import org.apache.joshua.server.ServerThread;
import org.apache.joshua.server.TcpServer;
import org.apache.log4j.Level;
import org.apache.log4j.LogManager;
-import org.apache.joshua.server.ServerThread;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import com.sun.net.httpserver.HttpServer;
+
/**
* Implements decoder initialization, including interaction with <code>JoshuaConfiguration</code>
* and <code>DecoderTask</code>.
- *
+ *
* @author Zhifei Li, zhifei.work@gmail.com
* @author wren ng thornton wren@users.sourceforge.net
* @author Lane Schwartz dowobeha@users.sourceforge.net
@@ -81,7 +81,7 @@ public class JoshuaDecoder {
} else if (joshuaConfiguration.server_type == SERVER_TYPE.HTTP) {
joshuaConfiguration.use_structured_output = true;
-
+
HttpServer server = HttpServer.create(new InetSocketAddress(port), 0);
LOG.info("HTTP Server running and listening on port {}.", port);
server.createContext("/", new ServerThread(null, decoder, joshuaConfiguration));
@@ -93,23 +93,22 @@ public class JoshuaDecoder {
}
return;
}
-
+
// Create a TranslationRequest object, reading from a file if requested, or from STDIN
- InputStream input = (joshuaConfiguration.input_file != null)
+ InputStream input = (joshuaConfiguration.input_file != null)
? new FileInputStream(joshuaConfiguration.input_file)
: System.in;
BufferedReader reader = new BufferedReader(new InputStreamReader(input));
TranslationRequestStream fileRequest = new TranslationRequestStream(reader, joshuaConfiguration);
TranslationResponseStream translationResponseStream = decoder.decodeAll(fileRequest);
-
+
// Create the n-best output stream
FileWriter nbest_out = null;
if (joshuaConfiguration.n_best_file != null)
nbest_out = new FileWriter(joshuaConfiguration.n_best_file);
for (Translation translation: translationResponseStream) {
-
/**
* We need to munge the feature value outputs in order to be compatible with Moses tuners.
* Whereas Joshua writes to STDOUT whatever is specified in the `output-format` parameter,
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java b/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
index 1ba19f0..60f0efe 100644
--- a/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
+++ b/src/main/java/org/apache/joshua/decoder/StructuredTranslationFactory.java
@@ -32,20 +32,18 @@ import org.apache.joshua.decoder.ff.FeatureFunction;
import org.apache.joshua.decoder.hypergraph.HyperGraph;
import org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationState;
import org.apache.joshua.decoder.segment_file.Sentence;
-import org.apache.joshua.decoder.segment_file.Token;
-import org.apache.joshua.util.FormatUtils;
/**
* This factory provides methods to create StructuredTranslation objects
* from either Viterbi derivations or KBest derivations.
- *
+ *
* @author fhieber
*/
public class StructuredTranslationFactory {
-
+
/**
* Returns a StructuredTranslation instance from the Viterbi derivation.
- *
+ *
* @param sourceSentence the source sentence
* @param hypergraph the hypergraph object
* @param featureFunctions the list of active feature functions
@@ -66,7 +64,7 @@ public class StructuredTranslationFactory {
getViterbiFeatures(hypergraph, featureFunctions, sourceSentence).getMap(),
(System.currentTimeMillis() - startTime) / 1000.0f);
}
-
+
/**
* Returns a StructuredTranslation from an empty decoder output
* @param sourceSentence the source sentence
@@ -76,9 +74,9 @@ public class StructuredTranslationFactory {
return new StructuredTranslation(
sourceSentence, "", emptyList(), 0, emptyList(), emptyMap(), 0f);
}
-
+
/**
- * Returns a StructuredTranslation instance from a KBest DerivationState.
+ * Returns a StructuredTranslation instance from a KBest DerivationState.
* @param sourceSentence Sentence object representing the source.
* @param derivationState the KBest DerivationState.
* @return A StructuredTranslation object representing the derivation encoded by derivationState.
@@ -97,7 +95,7 @@ public class StructuredTranslationFactory {
derivationState.getFeatures().getMap(),
(System.currentTimeMillis() - startTime) / 1000.0f);
}
-
+
private static float extractTranslationScore(final HyperGraph hypergraph) {
if (hypergraph == null) {
return 0;
@@ -105,7 +103,7 @@ public class StructuredTranslationFactory {
return hypergraph.goalNode.getScore();
}
}
-
+
private static List<String> extractTranslationTokens(final String translationString) {
if (translationString.isEmpty()) {
return emptyList();
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/Translation.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/Translation.java b/src/main/java/org/apache/joshua/decoder/Translation.java
index 142ff05..ade9b22 100644
--- a/src/main/java/org/apache/joshua/decoder/Translation.java
+++ b/src/main/java/org/apache/joshua/decoder/Translation.java
@@ -18,12 +18,11 @@
*/
package org.apache.joshua.decoder;
+import static org.apache.joshua.decoder.StructuredTranslationFactory.fromViterbiDerivation;
import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiFeatures;
import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiString;
import static org.apache.joshua.decoder.hypergraph.ViterbiExtractor.getViterbiWordAlignments;
-import static org.apache.joshua.decoder.StructuredTranslationFactory.fromViterbiDerivation;
import static org.apache.joshua.util.FormatUtils.removeSentenceMarkers;
-import static java.util.Arrays.asList;
import java.io.BufferedWriter;
import java.io.IOException;
@@ -45,7 +44,7 @@ import org.slf4j.LoggerFactory;
* This class represents translated input objects (sentences or lattices). It is aware of the source
* sentence and id and contains the decoded hypergraph. Translation objects are returned by
* DecoderTask instances to the InputHandler, where they are assembled in order for output.
- *
+ *
* @author Matt Post post@cs.jhu.edu
* @author Felix Hieber fhieber@amazon.com
*/
@@ -66,17 +65,17 @@ public class Translation {
* Else it will use KBestExtractor to populate this list.
*/
private List<StructuredTranslation> structuredTranslations = null;
-
- public Translation(Sentence source, HyperGraph hypergraph,
+
+ public Translation(Sentence source, HyperGraph hypergraph,
List<FeatureFunction> featureFunctions, JoshuaConfiguration joshuaConfiguration) {
this.source = source;
-
+
/**
* Structured output from Joshua provides a way to programmatically access translation results
* from downstream applications, instead of writing results as strings to an output buffer.
*/
if (joshuaConfiguration.use_structured_output) {
-
+
if (joshuaConfiguration.topN == 0) {
/*
* Obtain Viterbi StructuredTranslation
@@ -84,7 +83,7 @@ public class Translation {
StructuredTranslation translation = fromViterbiDerivation(source, hypergraph, featureFunctions);
this.output = translation.getTranslationString();
structuredTranslations = Collections.singletonList(translation);
-
+
} else {
/*
* Get K-Best list of StructuredTranslations
@@ -107,9 +106,9 @@ public class Translation {
BufferedWriter out = new BufferedWriter(sw);
try {
-
+
if (hypergraph != null) {
-
+
long startTime = System.currentTimeMillis();
// We must put this weight as zero, otherwise we get an error when we try to retrieve it
@@ -161,20 +160,20 @@ public class Translation {
}
}
- float seconds = (float) (System.currentTimeMillis() - startTime) / 1000.0f;
+ float seconds = (System.currentTimeMillis() - startTime) / 1000.0f;
LOG.info("Input {}: {}-best extraction took {} seconds", id(),
joshuaConfiguration.topN, seconds);
} else {
-
+
// Failed translations and blank lines get empty formatted outputs
out.write(getFailedTranslationOutput(source, joshuaConfiguration));
out.newLine();
-
+
}
out.flush();
-
+
} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -182,7 +181,7 @@ public class Translation {
this.output = sw.toString();
}
-
+
// remove state from StateMinimizingLanguageModel instances in features.
destroyKenLMStates(featureFunctions);
@@ -200,7 +199,7 @@ public class Translation {
public String toString() {
return output;
}
-
+
private String getFailedTranslationOutput(final Sentence source, final JoshuaConfiguration joshuaConfiguration) {
return joshuaConfiguration.outputFormat
.replace("%s", source.source())
@@ -211,7 +210,7 @@ public class Translation {
.replace("%f", "")
.replace("%c", "0.000");
}
-
+
/**
* Returns the StructuredTranslations
* if JoshuaConfiguration.use_structured_output == True.
@@ -225,7 +224,7 @@ public class Translation {
}
return structuredTranslations;
}
-
+
/**
* KenLM hack. If using KenLMFF, we need to tell KenLM to delete the pool used to create chart
* objects for this sentence.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java b/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
index 8b5c81a..0e5139a 100644
--- a/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
+++ b/src/main/java/org/apache/joshua/decoder/chart_parser/DotChart.java
@@ -19,11 +19,8 @@
package org.apache.joshua.decoder.chart_parser;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
-import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.ff.tm.Grammar;
import org.apache.joshua.decoder.ff.tm.Rule;
import org.apache.joshua.decoder.ff.tm.RuleCollection;
@@ -38,19 +35,19 @@ import org.slf4j.LoggerFactory;
/**
* The DotChart handles Earley-style implicit binarization of translation rules.
- *
+ *
* The {@link DotNode} object represents the (possibly partial) application of a synchronous rule.
* The implicit binarization is maintained with a pointer to the {@link Trie} node in the grammar,
* for easy retrieval of the next symbol to be matched. At every span (i,j) of the input sentence,
* every incomplete DotNode is examined to see whether it (a) needs a terminal and matches against
* the final terminal of the span or (b) needs a nonterminal and matches against a completed
* nonterminal in the main chart at some split point (k,j).
- *
+ *
* Once a rule is completed, it is entered into the {@link DotChart}. {@link DotCell} objects are
* used to group completed DotNodes over a span.
- *
+ *
* There is a separate DotChart for every grammar.
- *
+ *
* @author Zhifei Li, <zh...@gmail.com>
* @author Matt Post <po...@cs.jhu.edu>
* @author Kristy Hollingshead Seitz
@@ -108,7 +105,7 @@ class DotChart {
/**
* Constructs a new dot chart from a specified input lattice, a translation grammar, and a parse
* chart.
- *
+ *
* @param input A lattice which represents an input sentence.
* @param grammar A translation grammar.
* @param chart A CKY+ style chart in which completed span entries are stored.
@@ -142,17 +139,17 @@ class DotChart {
* This function computes all possible expansions of all rules over the provided span (i,j). By
* expansions, we mean the moving of the dot forward (from left to right) over a nonterminal or
* terminal symbol on the rule's source side.
- *
+ *
* There are two kinds of expansions:
- *
+ *
* <ol>
* <li>Expansion over a nonterminal symbol. For this kind of expansion, a rule has a dot
* immediately prior to a source-side nonterminal. The main Chart is consulted to see whether
* there exists a completed nonterminal with the same label. If so, the dot is advanced.
- *
+ *
* Discovering nonterminal expansions is a matter of enumerating all split points k such that i <
* k and k < j. The nonterminal symbol must exist in the main Chart over (k,j).
- *
+ *
* <li>Expansion over a terminal symbol. In this case, expansion is a simple matter of determing
* whether the input symbol at position j (the end of the span) matches the next symbol in the
* rule. This is equivalent to choosing a split point k = j - 1 and looking for terminal symbols
@@ -199,8 +196,6 @@ class DotChart {
// List<Trie> child_tnodes = ruleMatcher.produceMatchingChildTNodesTerminalevel(dotNode,
// last_word);
- List<Trie> child_tnodes = null;
-
Trie child_node = dotNode.trieNode.match(last_word);
if (null != child_node) {
addDotItem(child_node, i, j - 1 + arc_len, dotNode.antSuperNodes, null,
@@ -228,10 +223,10 @@ class DotChart {
* Attempt to combine an item in the dot chart with an item in the main chart to create a new item
* in the dot chart. The DotChart item is a {@link DotNode} begun at position i with the dot
* currently at position k, that is, a partially-applied rule.
- *
+ *
* In other words, this method looks for (proved) theorems or axioms in the completed chart that
* may apply and extend the dot position.
- *
+ *
* @param i Start index of a dot chart item
* @param k End index of a dot chart item; start index of a completed chart item
* @param j End index of a completed chart item
@@ -272,43 +267,10 @@ class DotChart {
}
}
- /*
- * We introduced the ability to have regular expressions in rules for matching against terminals.
- * For example, you could have the rule
- *
- * <pre> [X] ||| l?s herman?s ||| siblings </pre>
- *
- * When this is enabled for a grammar, we need to test against *all* (positive) outgoing arcs of
- * the grammar trie node to see if any of them match, and then return the whole set. This is quite
- * expensive, which is why you should only enable regular expressions for small grammars.
- */
-
- private ArrayList<Trie> matchAll(DotNode dotNode, int wordID) {
- ArrayList<Trie> trieList = new ArrayList<>();
- HashMap<Integer, ? extends Trie> childrenTbl = dotNode.trieNode.getChildren();
-
- if (childrenTbl != null && wordID >= 0) {
- // get all the extensions, map to string, check for *, build regexp
- for (Map.Entry<Integer, ? extends Trie> entry : childrenTbl.entrySet()) {
- Integer arcID = entry.getKey();
- if (arcID == wordID) {
- trieList.add(entry.getValue());
- } else {
- String arcWord = Vocabulary.word(arcID);
- if (Vocabulary.word(wordID).matches(arcWord)) {
- trieList.add(entry.getValue());
- }
- }
- }
- }
- return trieList;
- }
-
-
/**
* Creates a {@link DotNode} and adds it into the {@link DotChart} at the correct place. These
- * are (possibly incomplete) rule applications.
- *
+ * are (possibly incomplete) rule applications.
+ *
* @param tnode the trie node pointing to the location ("dot") in the grammar trie
* @param i
* @param j
@@ -382,10 +344,10 @@ class DotChart {
private final int i;
private final int j;
private Trie trieNode = null;
-
+
/* A list of grounded (over a span) nonterminals that have been crossed in traversing the rule */
private ArrayList<SuperNode> antSuperNodes = null;
-
+
/* The source lattice cost of applying the rule */
private final SourcePath srcPath;
@@ -396,11 +358,11 @@ class DotChart {
size = trieNode.getRuleCollection().getRules().size();
return String.format("DOTNODE i=%d j=%d #rules=%d #tails=%d", i, j, size, antSuperNodes.size());
}
-
+
/**
* Initialize a dot node with the span, grammar trie node, list of supernode tail pointers, and
* the lattice sourcepath.
- *
+ *
* @param i
* @param j
* @param trieNode
@@ -415,6 +377,7 @@ class DotChart {
this.srcPath = srcPath;
}
+ @Override
public boolean equals(Object obj) {
if (obj == null)
return false;
@@ -438,6 +401,7 @@ class DotChart {
* Technically the hash should include the span (i,j), but since DotNodes are grouped by span,
* this isn't necessary, and we gain something by not having to store the span.
*/
+ @Override
public int hashCode() {
return this.trieNode.hashCode();
}
@@ -446,7 +410,7 @@ class DotChart {
public boolean hasRules() {
return getTrieNode().getRuleCollection() != null && getTrieNode().getRuleCollection().getRules().size() != 0;
}
-
+
public RuleCollection getRuleCollection() {
return getTrieNode().getRuleCollection();
}
@@ -466,7 +430,7 @@ class DotChart {
public int begin() {
return i;
}
-
+
public int end() {
return j;
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java b/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
index 4e75af5..a9264e0 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/TargetBigram.java
@@ -80,8 +80,7 @@ public class TargetBigram extends StatefulFF {
this.vocab = new HashSet<>();
this.vocab.add("<s>");
this.vocab.add("</s>");
- try {
- LineReader lineReader = new LineReader(filename);
+ try(LineReader lineReader = new LineReader(filename);) {
for (String line: lineReader) {
if (lineReader.lineno() > maxTerms)
break;
@@ -189,7 +188,7 @@ public class TargetBigram extends StatefulFF {
}
/**
- * TargetBigram features are only computed across hyperedges, so there is nothing to be done here.
+ * TargetBigram features are only computed across hyperedges, so there is nothing to be done here.
*/
@Override
public float estimateCost(Rule rule, Sentence sentence) {
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
index 60e8d20..9933c73 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/fragmentlm/Tree.java
@@ -21,7 +21,14 @@ package org.apache.joshua.decoder.ff.fragmentlm;
import java.io.IOException;
import java.io.Serializable;
import java.io.StringReader;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.ff.fragmentlm.Trees.PennTreeReader;
@@ -39,7 +46,7 @@ import org.slf4j.LoggerFactory;
* addition to complete trees (the BP requires terminals to be immediately governed by a
* preterminal). To distinguish terminals from nonterminals in fragments, the former must be
* enclosed in double-quotes when read in.
- *
+ *
* @author Dan Klein
* @author Matt Post post@cs.jhu.edu
*/
@@ -72,13 +79,13 @@ public class Tree implements Serializable {
* This maps the flat right-hand sides of Joshua rules to the tree fragments they were derived
* from. It is used to lookup the fragment that language model fragments should be match against.
* For example, if the target (English) side of your rule is
- *
+ *
* [NP,1] said [SBAR,2]
- *
+ *
* we will retrieve the unflattened fragment
- *
+ *
* (S NP (VP (VBD said) SBAR))
- *
+ *
* which presumably was the fronter fragment used to derive the translation rule. With this in
* hand, we can iterate through our store of language model fragments to match them against this,
* following tail nodes if necessary.
@@ -114,7 +121,7 @@ public class Tree implements Serializable {
/**
* Computes the depth-one rule rooted at this node. If the node has no children, null is returned.
- *
+ *
* @return string representation of the rule
*/
public String getRule() {
@@ -182,7 +189,7 @@ public class Tree implements Serializable {
/**
* Clone the structure of the tree.
- *
+ *
* @return a cloned tree
*/
public Tree shallowClone() {
@@ -241,7 +248,7 @@ public class Tree implements Serializable {
* A tree is lexicalized if it has terminal nodes among the leaves of its frontier. For normal
* trees this is always true since they bottom out in terminals, but for fragments, this may or
* may not be true.
- *
+ *
* @return true if the tree is lexicalized
*/
public boolean isLexicalized() {
@@ -260,7 +267,7 @@ public class Tree implements Serializable {
/**
* The depth of a tree is the maximum distance from the root to any of the frontier nodes.
- *
+ *
* @return the tree depth
*/
public int getDepth() {
@@ -308,6 +315,7 @@ public class Tree implements Serializable {
this.label = Vocabulary.id(label);
}
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
toStringBuilder(sb);
@@ -317,13 +325,13 @@ public class Tree implements Serializable {
/**
* Removes the quotes around terminals. Note that the resulting tree could not be read back
* in by this class, since unquoted leaves are interpreted as nonterminals.
- *
+ *
* @return unquoted string
*/
public String unquotedString() {
return toString().replaceAll("\"", "");
}
-
+
public String escapedString() {
return toString().replaceAll(" ", "_");
}
@@ -349,7 +357,7 @@ public class Tree implements Serializable {
/**
* Get the set of all subtrees inside the tree by returning a tree rooted at each node. These are
* <i>not</i> copies, but all share structure. The tree is regarded as a subtree of itself.
- *
+ *
* @return the <code>Set</code> of all subtrees in the tree.
*/
public Set<Tree> subTrees() {
@@ -359,7 +367,7 @@ public class Tree implements Serializable {
/**
* Get the list of all subtrees inside the tree by returning a tree rooted at each node. These are
* <i>not</i> copies, but all share structure. The tree is regarded as a subtree of itself.
- *
+ *
* @return the <code>List</code> of all subtrees in the tree.
*/
public List<Tree> subTreeList() {
@@ -369,7 +377,7 @@ public class Tree implements Serializable {
/**
* Add the set of all subtrees inside a tree (including the tree itself) to the given
* <code>Collection</code>.
- *
+ *
* @param n A collection of nodes to which the subtrees will be added
* @return The collection parameter with the subtrees added
*/
@@ -387,7 +395,7 @@ public class Tree implements Serializable {
* <code>iterator()</code> method required by the <code>Collections</code> interface. It does a
* preorder (children after node) traversal of the tree. (A possible extension to the class at
* some point would be to allow different traversal orderings via variant iterators.)
- *
+ *
* @return An interator over the nodes of the tree
*/
public TreeIterator iterator() {
@@ -403,10 +411,12 @@ public class Tree implements Serializable {
treeStack.add(Tree.this);
}
+ @Override
public boolean hasNext() {
return (!treeStack.isEmpty());
}
+ @Override
public Tree next() {
int lastIndex = treeStack.size() - 1;
Tree tr = treeStack.remove(lastIndex);
@@ -421,6 +431,7 @@ public class Tree implements Serializable {
/**
* Not supported
*/
+ @Override
public void remove() {
throw new UnsupportedOperationException();
}
@@ -454,7 +465,7 @@ public class Tree implements Serializable {
* to the leftmost (rightmost) pre-terminal in the tree. This facilitates using trees as language
* models. The arguments have to be passed in to preserve Java generics, even though this is only
* ever used with String versions.
- *
+ *
* @param sos presumably "<s>"
* @param eos presumably "</s>"
*/
@@ -469,7 +480,7 @@ public class Tree implements Serializable {
}
/**
- *
+ *
* @param symbol the marker to insert
* @param pos the position at which to insert
*/
@@ -492,7 +503,7 @@ public class Tree implements Serializable {
/**
* This is a convenience function for producing a fragment from its string representation.
- *
+ *
* @param ptbStr input string from which to produce a fragment
* @return the fragment
*/
@@ -511,8 +522,7 @@ public class Tree implements Serializable {
public static void readMapping(String fragmentMappingFile) {
/* Read in the rule / fragments mapping */
- try {
- LineReader reader = new LineReader(fragmentMappingFile);
+ try (LineReader reader = new LineReader(fragmentMappingFile);) {
for (String line : reader) {
String[] fields = line.split("\\s+\\|{3}\\s+");
if (fields.length != 2 || !fields[0].startsWith("(")) {
@@ -535,14 +545,14 @@ public class Tree implements Serializable {
* the internal fragment corresponding to the rule; this will be the top of the tree. We then
* recursively visit the derivation state objects, following the route through the hypergraph
* defined by them.
- *
+ *
* This function is like Tree#buildTree(DerivationState, int),
* but that one simply follows the best incoming hyperedge for each node.
- *
+ *
* @param rule for which corresponding internal fragment can be used to initialize the tree
* @param derivationStates array of state objects
* @param maxDepth of route through the hypergraph
- * @return the Tree
+ * @return the Tree
*/
public static Tree buildTree(Rule rule, DerivationState[] derivationStates, int maxDepth) {
Tree tree = getFragmentFromYield(rule.getEnglishWords());
@@ -566,7 +576,7 @@ public class Tree implements Serializable {
* indices in the Vocabulary, while negative indices are used to nonterminals. These negative
* indices are a *permutation* of the source side nonterminals, which contain the actual
* nonterminal Vocabulary indices for the nonterminal names. Here, we convert this permutation
- * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
+ * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
* the incoming DerivationState items, which are ordered by the source side.
*/
ArrayList<Integer> tailIndices = new ArrayList<>();
@@ -604,23 +614,23 @@ public class Tree implements Serializable {
frontierTree.children = tree.children;
}
}
-
+
return tree;
}
-
+
/**
* <p>Builds a tree from the kth-best derivation state. This is done by initializing the tree with
* the internal fragment corresponding to the rule; this will be the top of the tree. We then
* recursively visit the derivation state objects, following the route through the hypergraph
* defined by them.</p>
- *
+ *
* @param derivationState array of state objects
* @param maxDepth of route through the hypergraph
* @return the Tree
*/
public static Tree buildTree(DerivationState derivationState, int maxDepth) {
Rule rule = derivationState.edge.getRule();
-
+
Tree tree = getFragmentFromYield(rule.getEnglishWords());
if (tree == null) {
@@ -628,7 +638,7 @@ public class Tree implements Serializable {
}
tree = tree.shallowClone();
-
+
LOG.debug("buildTree({})", tree);
if (rule.getArity() > 0 && maxDepth > 0) {
@@ -638,7 +648,7 @@ public class Tree implements Serializable {
* indices in the Vocabulary, while negative indices are used to nonterminals. These negative
* indices are a *permutation* of the source side nonterminals, which contain the actual
* nonterminal Vocabulary indices for the nonterminal names. Here, we convert this permutation
- * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
+ * to a nonnegative 0-based permutation and store it in tailIndices. This is used to index
* the incoming DerivationState items, which are ordered by the source side.
*/
ArrayList<Integer> tailIndices = new ArrayList<>();
@@ -667,16 +677,16 @@ public class Tree implements Serializable {
frontierTree.children = childTree.children;
}
}
-
+
return tree;
}
/**
* Takes a rule and its tail pointers and recursively constructs a tree (up to maxDepth).
- *
+ *
* This could be implemented by using the other buildTree() function and using the 1-best
* DerivationState.
- *
+ *
* @param rule {@link org.apache.joshua.decoder.ff.tm.Rule} to be used whilst building the tree
* @param tailNodes {@link java.util.List} of {@link org.apache.joshua.decoder.hypergraph.HGNode}'s
* @param maxDepth to go in the tree
@@ -745,16 +755,16 @@ public class Tree implements Serializable {
return tree;
}
- public static void main(String[] args) {
- LineReader reader = new LineReader(System.in);
-
- for (String line : reader) {
- try {
- Tree tree = Tree.fromString(line);
- tree.insertSentenceMarkers();
- System.out.println(tree);
- } catch (Exception e) {
- System.out.println("");
+ public static void main(String[] args) throws IOException {
+ try (LineReader reader = new LineReader(System.in);) {
+ for (String line : reader) {
+ try {
+ Tree tree = Tree.fromString(line);
+ tree.insertSentenceMarkers();
+ System.out.println(tree);
+ } catch (Exception e) {
+ System.out.println("");
+ }
}
}
@@ -762,14 +772,14 @@ public class Tree implements Serializable {
* Tree fragment = Tree
* .fromString("(TOP (S (NP (DT the) (NN boy)) (VP (VBD ate) (NP (DT the) (NN food)))))");
* fragment.insertSentenceMarkers("<s>", "</s>");
- *
+ *
* System.out.println(fragment);
- *
+ *
* ArrayList<Tree> trees = new ArrayList<Tree>(); trees.add(Tree.fromString("(NN \"mat\")"));
* trees.add(Tree.fromString("(S (NP DT NN) VP)"));
* trees.add(Tree.fromString("(S (NP (DT \"the\") NN) VP)"));
* trees.add(Tree.fromString("(S (NP (DT the) NN) VP)"));
- *
+ *
* for (Tree tree : trees) { System.out.println(String.format("TREE %s DEPTH %d LEX? %s", tree,
* tree.getDepth(), tree.isLexicalized())); }
*/
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
index 93d54ed..044c85f 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/KenLM.java
@@ -29,7 +29,7 @@ import org.slf4j.LoggerFactory;
* feature functions KenLMFF and LanguageModelFF. KenLMFF uses the RuleScore() interface in
* lm/left.hh, returning a state pointer representing the KenLM state, while LangaugeModelFF handles
* state by itself and just passes in the ngrams for scoring.
- *
+ *
* @author Kenneth Heafield
* @author Matt Post post@cs.jhu.edu
*/
@@ -58,11 +58,11 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
private static native float probForString(long ptr, String[] words);
private static native boolean isKnownWord(long ptr, String word);
-
+
private static native boolean isLmOov(long ptr, int word);
private static native StateProbPair probRule(long ptr, long pool, long words[]);
-
+
private static native float estimateRule(long ptr, long words[]);
private static native float probString(long ptr, int words[], int start);
@@ -98,7 +98,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
}
}
- public class KenLMLoadException extends RuntimeException {
+ public static class KenLMLoadException extends RuntimeException {
public KenLMLoadException(UnsatisfiedLinkError e) {
super(e);
@@ -117,10 +117,12 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
destroy(pointer);
}
+ @Override
public int getOrder() {
return ngramOrder;
}
+ @Override
public boolean registerWord(String word, int id) {
return registerWord(pointer, word, id);
}
@@ -149,10 +151,10 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
* rule). Nonterminals have a negative value so KenLM can distinguish them. The sentence number is
* needed so KenLM knows which memory pool to use. When finished, it returns the updated KenLM
* state and the LM probability incurred along this rule.
- *
+ *
* @param words array of words
* @param poolPointer todo
- * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
+ * @return the updated {@link org.apache.joshua.decoder.ff.lm.KenLM.StateProbPair} e.g.
* KenLM state and the LM probability incurred along this rule
*/
public StateProbPair probRule(long[] words, long poolPointer) {
@@ -171,7 +173,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
/**
* Public facing function that estimates the cost of a rule, which value is used for sorting
* rules during cube pruning.
- *
+ *
* @param words array of words
* @return the estimated cost of the rule (the (partial) n-gram probabilities of all words in the rule)
*/
@@ -182,7 +184,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
} catch (NoSuchMethodError e) {
throw new RuntimeException(e);
}
-
+
return estimate;
}
@@ -193,7 +195,7 @@ public class KenLM implements NGramLanguageModel, Comparable<KenLM> {
public String getStartSymbol() {
return Vocabulary.START_SYM;
}
-
+
/**
* Returns whether the given Vocabulary ID is unknown to the
* KenLM vocabulary. This can be used for a LanguageModel_OOV features
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java b/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
index 9bfccb0..0615077 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/lm/buildin_lm/TrieLM.java
@@ -44,12 +44,12 @@ import org.slf4j.LoggerFactory;
* <p>
* The trie itself represents language model context.
* <p>
- * Conceptually, each node in the trie stores a map
+ * Conceptually, each node in the trie stores a map
* from conditioning word to log probability.
* <p>
- * Additionally, each node in the trie stores
+ * Additionally, each node in the trie stores
* the backoff weight for that context.
- *
+ *
* @author Lane Schwartz
* @see <a href="http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html">SRILM ngram-discount documentation</a>
*/
@@ -63,23 +63,23 @@ public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
private static final int ROOT_NODE_ID = 0;
- /**
- * Maps from (node id, word id for child) --> node id of child.
+ /**
+ * Maps from (node id, word id for child) --> node id of child.
*/
private final Map<Long,Integer> children;
/**
- * Maps from (node id, word id for lookup word) -->
- * log prob of lookup word given context
- *
+ * Maps from (node id, word id for lookup word) -->
+ * log prob of lookup word given context
+ *
* (the context is defined by where you are in the tree).
*/
private final Map<Long,Float> logProbs;
/**
- * Maps from (node id) -->
- * backoff weight for that context
- *
+ * Maps from (node id) -->
+ * backoff weight for that context
+ *
* (the context is defined by where you are in the tree).
*/
private final Map<Integer,Float> backoffs;
@@ -90,7 +90,7 @@ public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
/**
* Constructs a language model object from the specified ARPA file.
- *
+ *
* @param arpaFile input ARPA file
* @throws FileNotFoundException if the input file cannot be located
*/
@@ -149,7 +149,7 @@ public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
{
// Find where the backoff should be stored
int backoffNodeID = ROOT_NODE_ID;
- {
+ {
long backoffNodeKey = Bits.encodeAsLong(backoffNodeID, word);
int wordChildID;
if (children.containsKey(backoffNodeKey)) {
@@ -188,62 +188,13 @@ public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
@Override
- protected double logProbabilityOfBackoffState_helper(
- int[] ngram, int order, int qtyAdditionalBackoffWeight
- ) {
+ protected double logProbabilityOfBackoffState_helper(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
throw new UnsupportedOperationException("probabilityOfBackoffState_helper undefined for TrieLM");
}
@Override
protected float ngramLogProbability_helper(int[] ngram, int order) {
-
-// float logProb = (float) -JoshuaConfiguration.lm_ceiling_cost;//Float.NEGATIVE_INFINITY; // log(0.0f)
- float backoff = 0.0f; // log(1.0f)
-
- int i = ngram.length - 1;
- int word = ngram[i];
- i -= 1;
-
- int nodeID = ROOT_NODE_ID;
-
- while (true) {
-
- {
- long key = Bits.encodeAsLong(nodeID, word);
- if (logProbs.containsKey(key)) {
-// logProb = logProbs.get(key);
- backoff = 0.0f; // log(0.0f)
- }
- }
-
- if (i < 0) {
- break;
- }
-
- {
- long key = Bits.encodeAsLong(nodeID, ngram[i]);
-
- if (children.containsKey(key)) {
- nodeID = children.get(key);
-
- backoff += backoffs.get(nodeID);
-
- i -= 1;
-
- } else {
- break;
- }
- }
-
- }
-
-// double result = logProb + backoff;
-// if (result < -JoshuaConfiguration.lm_ceiling_cost) {
-// result = -JoshuaConfiguration.lm_ceiling_cost;
-// }
-//
-// return result;
- return (Float) null;
+ throw new UnsupportedOperationException();
}
public Map<Long,Integer> getChildren() {
@@ -264,66 +215,65 @@ public class TrieLM extends AbstractLM { //DefaultNGramLanguageModel {
int n = Integer.valueOf(args[2]);
LOG.info("N-gram order will be {}", n);
- Scanner scanner = new Scanner(new File(args[1]));
+ try (Scanner scanner = new Scanner(new File(args[1]));) {
+ LinkedList<String> wordList = new LinkedList<>();
+ LinkedList<String> window = new LinkedList<>();
- LinkedList<String> wordList = new LinkedList<>();
- LinkedList<String> window = new LinkedList<>();
+ LOG.info("Starting to scan {}", args[1]);
+ while (scanner.hasNext()) {
- LOG.info("Starting to scan {}", args[1]);
- while (scanner.hasNext()) {
+ LOG.info("Getting next line...");
+ String line = scanner.nextLine();
+ LOG.info("Line: {}", line);
- LOG.info("Getting next line...");
- String line = scanner.nextLine();
- LOG.info("Line: {}", line);
+ String[] words = Regex.spaces.split(line);
+ wordList.clear();
- String[] words = Regex.spaces.split(line);
- wordList.clear();
+ wordList.add("<s>");
+ Collections.addAll(wordList, words);
+ wordList.add("</s>");
- wordList.add("<s>");
- Collections.addAll(wordList, words);
- wordList.add("</s>");
-
- ArrayList<Integer> sentence = new ArrayList<>();
- // int[] ids = new int[wordList.size()];
- for (String aWordList : wordList) {
- sentence.add(Vocabulary.id(aWordList));
- // ids[i] = ;
- }
+ ArrayList<Integer> sentence = new ArrayList<>();
+ // int[] ids = new int[wordList.size()];
+ for (String aWordList : wordList) {
+ sentence.add(Vocabulary.id(aWordList));
+ // ids[i] = ;
+ }
+ while (!wordList.isEmpty()) {
+ window.clear();
+ {
+ int i = 0;
+ for (String word : wordList) {
+ if (i >= n)
+ break;
+ window.add(word);
+ i++;
+ }
+ wordList.remove();
+ }
- while (! wordList.isEmpty()) {
- window.clear();
+ {
+ int i = 0;
+ int[] wordIDs = new int[window.size()];
+ for (String word : window) {
+ wordIDs[i] = Vocabulary.id(word);
+ i++;
+ }
- {
- int i=0;
- for (String word : wordList) {
- if (i>=n) break;
- window.add(word);
- i++;
+ LOG.info("logProb {} = {}", window, lm.ngramLogProbability(wordIDs, n));
}
- wordList.remove();
}
- {
- int i=0;
- int[] wordIDs = new int[window.size()];
- for (String word : window) {
- wordIDs[i] = Vocabulary.id(word);
- i++;
- }
+ double logProb = lm.sentenceLogProbability(sentence, n, 2);// .ngramLogProbability(ids,
+ // n);
+ double prob = Math.exp(logProb);
- LOG.info("logProb {} = {}", window, lm.ngramLogProbability(wordIDs, n));
- }
+ LOG.info("Total logProb = {}", logProb);
+ LOG.info("Total prob = {}", prob);
}
-
- double logProb = lm.sentenceLogProbability(sentence, n, 2);//.ngramLogProbability(ids, n);
- double prob = Math.exp(logProb);
-
- LOG.info("Total logProb = {}", logProb);
- LOG.info("Total prob = {}", prob);
}
-
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java b/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
index 2424a1e..e8242f6 100644
--- a/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
+++ b/src/main/java/org/apache/joshua/decoder/ff/tm/CreateGlueGrammar.java
@@ -30,7 +30,6 @@ import java.util.Set;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.util.io.LineReader;
-
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
@@ -46,7 +45,7 @@ public class CreateGlueGrammar {
@Option(name = "--grammar", aliases = {"-g"}, required = true, usage = "provide grammar to determine list of NonTerminal symbols.")
private String grammarPath;
-
+
@Option(name = "--goal", aliases = {"-goal"}, required = false, usage = "specify custom GOAL symbol. Default: 'GOAL'")
private final String goalSymbol = cleanNonTerminal(new JoshuaConfiguration().goal_symbol);
@@ -59,9 +58,9 @@ public class CreateGlueGrammar {
private static final String R_END = "[%1$s] ||| [%1$s,1] </s> ||| [%1$s,1] </s> ||| 0";
// [GOAL] ||| <s> [X,1] </s> ||| <s> [X,1] </s> ||| 0
private static final String R_TOP = "[%1$s] ||| <s> [%2$s,1] </s> ||| <s> [%2$s,1] </s> ||| 0";
-
+
private void run() throws IOException {
-
+
File grammar_file = new File(grammarPath);
if (!grammar_file.exists()) {
throw new IOException("Grammar file doesn't exist: " + grammarPath);
@@ -78,38 +77,39 @@ public class CreateGlueGrammar {
}
}
// otherwise we collect cleaned left-hand sides from the rules in the text grammar.
- } else {
- final LineReader reader = new LineReader(grammarPath);
- while (reader.hasNext()) {
- final String line = reader.next();
- int lhsStart = line.indexOf("[") + 1;
- int lhsEnd = line.indexOf("]");
- if (lhsStart < 1 || lhsEnd < 0) {
- LOG.info("malformed rule: {}\n", line);
- continue;
+ } else {
+ try (final LineReader reader = new LineReader(grammarPath);) {
+ while (reader.hasNext()) {
+ final String line = reader.next();
+ int lhsStart = line.indexOf("[") + 1;
+ int lhsEnd = line.indexOf("]");
+ if (lhsStart < 1 || lhsEnd < 0) {
+ LOG.info("malformed rule: {}\n", line);
+ continue;
+ }
+ final String lhs = line.substring(lhsStart, lhsEnd);
+ nonTerminalSymbols.add(lhs);
}
- final String lhs = line.substring(lhsStart, lhsEnd);
- nonTerminalSymbols.add(lhs);
}
}
-
+
LOG.info("{} nonTerminal symbols read: {}", nonTerminalSymbols.size(),
nonTerminalSymbols.toString());
// write glue rules to stdout
-
+
System.out.println(String.format(R_START, goalSymbol));
-
+
for (String nt : nonTerminalSymbols)
System.out.println(String.format(R_TWO, goalSymbol, nt));
-
+
System.out.println(String.format(R_END, goalSymbol));
-
+
for (String nt : nonTerminalSymbols)
System.out.println(String.format(R_TOP, goalSymbol, nt));
}
-
+
public static void main(String[] args) throws IOException {
final CreateGlueGrammar glueCreator = new CreateGlueGrammar();
final CmdLineParser parser = new CmdLineParser(glueCreator);
[10/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
index 6eaced4,0000000..f51a5b3
mode 100755,000000..100755
--- a/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/mira/Optimizer.java
@@@ -1,643 -1,0 +1,641 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.mira;
+
++import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
- import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.Vector;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.metrics.EvaluationMetric;
+
+// this class implements the MIRA algorithm
+public class Optimizer {
+ public Optimizer(Vector<String> _output, boolean[] _isOptimizable, double[] _initialLambda,
+ HashMap<String, String>[] _feat_hash, HashMap<String, String>[] _stats_hash) {
+ output = _output; // (not used for now)
+ isOptimizable = _isOptimizable;
+ initialLambda = _initialLambda; // initial weights array
+ paramDim = initialLambda.length - 1;
+ initialLambda = _initialLambda;
+ feat_hash = _feat_hash; // feature hash table
+ stats_hash = _stats_hash; // suff. stats hash table
+ finalLambda = new double[initialLambda.length];
+ for (int i = 0; i < finalLambda.length; i++)
+ finalLambda[i] = initialLambda[i];
+ }
+
+ // run MIRA for one epoch
+ public double[] runOptimizer() {
+ List<Integer> sents = new ArrayList<Integer>();
+ for (int i = 0; i < sentNum; ++i)
+ sents.add(i);
+ double[] avgLambda = new double[initialLambda.length]; // only needed if averaging is required
+ for (int i = 0; i < initialLambda.length; i++)
+ avgLambda[i] = 0.0;
+ double[] bestLambda = new double[initialLambda.length]; // only needed if averaging is required
+ for (int i = 0; i < initialLambda.length; i++)
+ bestLambda[i] = 0.0;
+ double bestMetricScore = evalMetric.getToBeMinimized() ? PosInf : NegInf;
+ int bestIter = 0;
+ for (int iter = 0; iter < miraIter; ++iter) {
+ System.arraycopy(finalLambda, 1, initialLambda, 1, paramDim);
+ if (needShuffle)
+ Collections.shuffle(sents);
+
+ double oraMetric, oraScore, predMetric, predScore;
+ double[] oraPredScore = new double[4];
+ double eta = 1.0; // learning rate, will not be changed if run percep
+ double avgEta = 0; // average eta, just for analysis
+ double loss = 0;
+ double diff = 0;
+ double featNorm = 0;
+ double sumMetricScore = 0;
+ double sumModelScore = 0;
+ String oraFeat = "";
+ String predFeat = "";
+ String[] oraPredFeat = new String[2];
+ String[] vecOraFeat;
+ String[] vecPredFeat;
+ String[] featInfo;
+ int thisBatchSize = 0;
+ int numBatch = 0;
- int numUpdate = 0;
- Iterator it;
+ Integer diffFeatId;
+
+ // update weights
+ Integer s;
+ int sentCount = 0;
+ while( sentCount < sentNum ) {
+ loss = 0;
+ thisBatchSize = batchSize;
+ ++numBatch;
+ HashMap<Integer, Double> featDiff = new HashMap<Integer, Double>();
+ for(int b = 0; b < batchSize; ++b ) {
+ //find out oracle and prediction
+ s = sents.get(sentCount);
+ // find out oracle and prediction
+ findOraPred(s, oraPredScore, oraPredFeat, finalLambda, featScale);
-
++
+ // the model scores here are already scaled in findOraPred
+ oraMetric = oraPredScore[0];
+ oraScore = oraPredScore[1];
+ predMetric = oraPredScore[2];
+ predScore = oraPredScore[3];
+ oraFeat = oraPredFeat[0];
+ predFeat = oraPredFeat[1];
-
++
+ // update the scale
+ if (needScale) { // otherwise featscale remains 1.0
+ sumMetricScore += java.lang.Math.abs(oraMetric + predMetric);
+ // restore the original model score
+ sumModelScore += java.lang.Math.abs(oraScore + predScore) / featScale;
+
+ if (sumModelScore / sumMetricScore > scoreRatio)
+ featScale = sumMetricScore / sumModelScore;
+ }
+
+ vecOraFeat = oraFeat.split("\\s+");
+ vecPredFeat = predFeat.split("\\s+");
-
++
+ //accumulate difference feature vector
+ if ( b == 0 ) {
+ for (int i = 0; i < vecOraFeat.length; i++) {
+ featInfo = vecOraFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (int i = 0; i < vecPredFeat.length; i++) {
+ featInfo = vecPredFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)-Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the 2nd feature vector
+ featDiff.put(diffFeatId, -1.0*Double.parseDouble(featInfo[1]));
+ }
+ } else {
+ for (int i = 0; i < vecOraFeat.length; i++) {
+ featInfo = vecOraFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)+Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the new oracle feature vector
+ featDiff.put(diffFeatId, Double.parseDouble(featInfo[1]));
+ }
+ for (int i = 0; i < vecPredFeat.length; i++) {
+ featInfo = vecPredFeat[i].split("=");
+ diffFeatId = Integer.parseInt(featInfo[0]);
+ if (featDiff.containsKey(diffFeatId)) { //overlapping features
+ diff = featDiff.get(diffFeatId)-Double.parseDouble(featInfo[1]);
+ if ( Math.abs(diff) > 1e-20 )
+ featDiff.put(diffFeatId, diff);
+ else
+ featDiff.remove(diffFeatId);
+ }
+ else //features only firing in the new prediction feature vector
+ featDiff.put(diffFeatId, -1.0*Double.parseDouble(featInfo[1]));
+ }
+ }
+ if (!runPercep) { // otherwise eta=1.0
+ // remember the model scores here are already scaled
+ double singleLoss = evalMetric.getToBeMinimized() ?
+ (predMetric - oraMetric) - (oraScore - predScore) / featScale
+ : (oraMetric - predMetric) - (oraScore - predScore) / featScale;
+ loss += singleLoss;
+ }
+ ++sentCount;
+ if( sentCount >= sentNum ) {
+ thisBatchSize = b + 1;
+ break;
+ }
+ } //for(int b = 0; b < batchSize; ++b)
+
+ if (!runPercep) { // otherwise eta=1.0
+ featNorm = 0;
+ Collection<Double> allDiff = featDiff.values();
- for (it = allDiff.iterator(); it.hasNext();) {
- diff = (Double) it.next();
++ for (Iterator<Double> it = allDiff.iterator(); it.hasNext();) {
++ diff = it.next();
+ featNorm += diff * diff / ( thisBatchSize * thisBatchSize );
+ }
+ }
+ if( loss <= 0 )
+ eta = 0;
+ else {
+ loss /= thisBatchSize;
+ // feat vector not scaled before
+ eta = C < loss / featNorm ? C : loss / featNorm;
+ }
+ avgEta += eta;
+ Set<Integer> diffFeatSet = featDiff.keySet();
- it = diffFeatSet.iterator();
++ Iterator<Integer> it = diffFeatSet.iterator();
+ if ( java.lang.Math.abs(eta) > 1e-20 ) {
+ while (it.hasNext()) {
- diffFeatId = (Integer) it.next();
++ diffFeatId = it.next();
+ finalLambda[diffFeatId] =
+ finalLambda[diffFeatId] + eta * featDiff.get(diffFeatId) / thisBatchSize;
+ }
+ }
+ if (needAvg) {
+ for (int i = 0; i < avgLambda.length; ++i)
+ avgLambda[i] += finalLambda[i];
+ }
+ } //while( sentCount < sentNum )
+
+ avgEta /= numBatch;
+
+ /*
+ * for( int i=0; i<finalLambda.length; i++ ) System.out.print(finalLambda[i]+" ");
+ * System.out.println(); System.exit(0);
+ */
+
+ double initMetricScore;
+ if(iter == 0 ) {
+ initMetricScore = computeCorpusMetricScore(initialLambda);
+ if(needAvg)
+ finalMetricScore = computeCorpusMetricScore(avgLambda);
+ else
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ } else {
+ initMetricScore = finalMetricScore;
+ if(needAvg)
+ finalMetricScore = computeCorpusMetricScore(avgLambda);
+ else
+ finalMetricScore = computeCorpusMetricScore(finalLambda);
+ }
+
+ if(evalMetric.getToBeMinimized()) {
+ if( finalMetricScore < bestMetricScore ) {
+ bestMetricScore = finalMetricScore;
+ bestIter = iter;
+ for( int i = 0; i < finalLambda.length; ++i )
+ bestLambda[i] = needAvg ? avgLambda[i] : finalLambda[i];
+ }
+ } else {
+ if( finalMetricScore > bestMetricScore ) {
+ bestMetricScore = finalMetricScore;
+ bestIter = iter;
+ for( int i = 0; i < finalLambda.length; ++i )
+ bestLambda[i] = needAvg ? avgLambda[i] : finalLambda[i];
+ }
+ }
+
+ if ( iter == miraIter - 1 ) {
+ for (int i = 0; i < finalLambda.length; ++i)
+ finalLambda[i] =
+ needAvg ? bestLambda[i] / ( numBatch * ( bestIter + 1 ) ) : bestLambda[i];
+ }
+
+ // prepare the printing info
+ String result = "Iter " + iter + ": Avg learning rate=" + String.format("%.4f", avgEta);
+ result += " Initial " + evalMetric.get_metricName() + "="
+ + String.format("%.4f", initMetricScore) + " Final " + evalMetric.get_metricName() + "="
+ + String.format("%.4f", finalMetricScore);
+ output.add(result);
+ } // for ( int iter = 0; iter < miraIter; ++iter )
+ String result = "Best " + evalMetric.get_metricName() + "="
+ + String.format("%.4f", bestMetricScore)
+ + " (iter = " + bestIter + ")\n";
+ output.add(result);
+ finalMetricScore = bestMetricScore;
+
+ // non-optimizable weights should remain unchanged
+ ArrayList<Double> cpFixWt = new ArrayList<Double>();
+ for (int i = 1; i < isOptimizable.length; ++i) {
+ if (!isOptimizable[i])
+ cpFixWt.add(finalLambda[i]);
+ }
+ normalizeLambda(finalLambda);
+ int countNonOpt = 0;
+ for (int i = 1; i < isOptimizable.length; ++i) {
+ if (!isOptimizable[i]) {
+ finalLambda[i] = cpFixWt.get(countNonOpt);
+ ++countNonOpt;
+ }
+ }
+ return finalLambda;
+ }
+
+ public double computeCorpusMetricScore(double[] finalLambda) {
+ int suffStatsCount = evalMetric.get_suffStatsCount();
+ double modelScore;
+ double maxModelScore;
+ Set<String> candSet;
+ String candStr;
+ String[] feat_str;
+ String[] tmpStatsVal = new String[suffStatsCount];
+ int[] corpusStatsVal = new int[suffStatsCount];
+ for (int i = 0; i < suffStatsCount; i++)
+ corpusStatsVal[i] = 0;
+
+ for (int i = 0; i < sentNum; i++) {
+ candSet = feat_hash[i].keySet();
+ // find out the 1-best candidate for each sentence
+ // this depends on the training mode
+ maxModelScore = NegInf;
- for (Iterator it = candSet.iterator(); it.hasNext();) {
++ for (Iterator<String> it = candSet.iterator(); it.hasNext();) {
+ modelScore = 0.0;
+ candStr = it.next().toString();
+ feat_str = feat_hash[i].get(candStr).split("\\s+");
+ String[] feat_info;
+ for (int f = 0; f < feat_str.length; f++) {
+ feat_info = feat_str[f].split("=");
+ modelScore += Double.parseDouble(feat_info[1]) * finalLambda[Vocabulary.id(feat_info[0])];
+ }
+ if (maxModelScore < modelScore) {
+ maxModelScore = modelScore;
+ tmpStatsVal = stats_hash[i].get(candStr).split("\\s+"); // save the
+ // suff stats
+ }
+ }
+
+ for (int j = 0; j < suffStatsCount; j++)
+ corpusStatsVal[j] += Integer.parseInt(tmpStatsVal[j]); // accumulate
+ // corpus-leve
+ // suff stats
+ } // for( int i=0; i<sentNum; i++ )
+
+ return evalMetric.score(corpusStatsVal);
+ }
+
+ private void findOraPred(int sentId, double[] oraPredScore, String[] oraPredFeat,
+ double[] lambda, double featScale) {
+ double oraMetric = 0, oraScore = 0, predMetric = 0, predScore = 0;
+ String oraFeat = "", predFeat = "";
+ double candMetric = 0, candScore = 0; // metric and model scores for each cand
+ Set<String> candSet = stats_hash[sentId].keySet();
+ String cand = "";
+ String feats = "";
+ String oraCand = ""; // only used when BLEU/TER-BLEU is used as metric
+ String[] featStr;
+ String[] featInfo;
+
+ int actualFeatId;
+ double bestOraScore;
+ double worstPredScore;
+
+ if (oraSelectMode == 1)
+ bestOraScore = NegInf; // larger score will be selected
+ else {
+ if (evalMetric.getToBeMinimized())
+ bestOraScore = PosInf; // smaller score will be selected
+ else
+ bestOraScore = NegInf;
+ }
+
+ if (predSelectMode == 1 || predSelectMode == 2)
+ worstPredScore = NegInf; // larger score will be selected
+ else {
+ if (evalMetric.getToBeMinimized())
+ worstPredScore = NegInf; // larger score will be selected
+ else
+ worstPredScore = PosInf;
+ }
+
- for (Iterator it = candSet.iterator(); it.hasNext();) {
++ for (Iterator<String> it = candSet.iterator(); it.hasNext();) {
+ cand = it.next().toString();
+ candMetric = computeSentMetric(sentId, cand); // compute metric score
+
+ // start to compute model score
+ candScore = 0;
+ featStr = feat_hash[sentId].get(cand).split("\\s+");
+ feats = "";
+
+ for (int i = 0; i < featStr.length; i++) {
+ featInfo = featStr[i].split("=");
+ actualFeatId = Vocabulary.id(featInfo[0]);
+ candScore += Double.parseDouble(featInfo[1]) * lambda[actualFeatId];
+ if ((actualFeatId < isOptimizable.length && isOptimizable[actualFeatId])
+ || actualFeatId >= isOptimizable.length)
+ feats += actualFeatId + "=" + Double.parseDouble(featInfo[1]) + " ";
+ }
+
+ candScore *= featScale; // scale the model score
+
+ // is this cand oracle?
+ if (oraSelectMode == 1) {// "hope", b=1, r=1
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (bestOraScore <= (candScore - candMetric)) {
+ bestOraScore = candScore - candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= (candScore + candMetric)) {
+ bestOraScore = candScore + candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ } else {// best metric score(ex: max BLEU), b=1, r=0
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (bestOraScore >= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ } else {
+ if (bestOraScore <= candMetric) {
+ bestOraScore = candMetric;
+ oraMetric = candMetric;
+ oraScore = candScore;
+ oraFeat = feats;
+ oraCand = cand;
+ }
+ }
+ }
+
+ // is this cand prediction?
+ if (predSelectMode == 1) {// "fear"
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (worstPredScore <= (candScore + candMetric)) {
+ worstPredScore = candScore + candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore <= (candScore - candMetric)) {
+ worstPredScore = candScore - candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ } else if (predSelectMode == 2) {// model prediction(max model score)
+ if (worstPredScore <= candScore) {
+ worstPredScore = candScore;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {// worst metric score(ex: min BLEU)
+ if (evalMetric.getToBeMinimized()) {// if the smaller the metric score, the better
+ if (worstPredScore <= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ } else {
+ if (worstPredScore >= candMetric) {
+ worstPredScore = candMetric;
+ predMetric = candMetric;
+ predScore = candScore;
+ predFeat = feats;
+ }
+ }
+ }
+ }
+
+ oraPredScore[0] = oraMetric;
+ oraPredScore[1] = oraScore;
+ oraPredScore[2] = predMetric;
+ oraPredScore[3] = predScore;
+ oraPredFeat[0] = oraFeat;
+ oraPredFeat[1] = predFeat;
+
+ // update the BLEU metric statistics if pseudo corpus is used to compute BLEU/TER-BLEU
+ if (evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ bleuHistory[sentId][j] = R * bleuHistory[sentId][j] + Integer.parseInt(statVal_str[j]);
+ }
+
+ if (evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ String statString;
+ String[] statVal_str;
+ statString = stats_hash[sentId].get(oraCand);
+ statVal_str = statString.split("\\s+");
+
+ for (int j = 0; j < evalMetric.get_suffStatsCount() - 2; j++)
+ bleuHistory[sentId][j] = R * bleuHistory[sentId][j] + Integer.parseInt(statVal_str[j + 2]); // the
+ // first
+ // 2
+ // stats
+ // are
+ // TER
+ // stats
+ }
+ }
+
+ // compute *sentence-level* metric score for cand
+ private double computeSentMetric(int sentId, String cand) {
+ String statString;
+ String[] statVal_str;
+ int[] statVal = new int[evalMetric.get_suffStatsCount()];
+
+ statString = stats_hash[sentId].get(cand);
+ statVal_str = statString.split("\\s+");
+
+ if (evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = (int) (Integer.parseInt(statVal_str[j]) + bleuHistory[sentId][j]);
+ } else if (evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ for (int j = 0; j < evalMetric.get_suffStatsCount() - 2; j++)
+ statVal[j + 2] = (int) (Integer.parseInt(statVal_str[j + 2]) + bleuHistory[sentId][j]); // only
+ // modify
+ // the
+ // BLEU
+ // stats
+ // part(TER
+ // has
+ // 2
+ // stats)
+ } else { // in all other situations, use normal stats
+ for (int j = 0; j < evalMetric.get_suffStatsCount(); j++)
+ statVal[j] = Integer.parseInt(statVal_str[j]);
+ }
+
+ return evalMetric.score(statVal);
+ }
+
+ // from ZMERT
+ private void normalizeLambda(double[] origLambda) {
+ // private String[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ int normalizationMethod = (int) normalizationOptions[0];
+ double scalingFactor = 1.0;
+ if (normalizationMethod == 0) {
+ scalingFactor = 1.0;
+ } else if (normalizationMethod == 1) {
+ int c = (int) normalizationOptions[2];
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[c]);
+ } else if (normalizationMethod == 2) {
+ double maxAbsVal = -1;
+ int maxAbsVal_c = 0;
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) > maxAbsVal) {
+ maxAbsVal = Math.abs(origLambda[c]);
+ maxAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[maxAbsVal_c]);
+
+ } else if (normalizationMethod == 3) {
+ double minAbsVal = PosInf;
+ int minAbsVal_c = 0;
+
+ for (int c = 1; c <= paramDim; ++c) {
+ if (Math.abs(origLambda[c]) < minAbsVal) {
+ minAbsVal = Math.abs(origLambda[c]);
+ minAbsVal_c = c;
+ }
+ }
+ scalingFactor = normalizationOptions[1] / Math.abs(origLambda[minAbsVal_c]);
+
+ } else if (normalizationMethod == 4) {
+ double pow = normalizationOptions[1];
+ double norm = L_norm(origLambda, pow);
+ scalingFactor = normalizationOptions[2] / norm;
+ }
+
+ for (int c = 1; c <= paramDim; ++c) {
+ origLambda[c] *= scalingFactor;
+ }
+ }
+
+ // from ZMERT
+ private double L_norm(double[] A, double pow) {
+ // calculates the L-pow norm of A[]
+ // NOTE: this calculation ignores A[0]
+ double sum = 0.0;
+ for (int i = 1; i < A.length; ++i)
+ sum += Math.pow(Math.abs(A[i]), pow);
+
+ return Math.pow(sum, 1 / pow);
+ }
+
+ public static double getScale() {
+ return featScale;
+ }
+
+ public static void initBleuHistory(int sentNum, int statCount) {
+ bleuHistory = new double[sentNum][statCount];
+ for (int i = 0; i < sentNum; i++) {
+ for (int j = 0; j < statCount; j++) {
+ bleuHistory[i][j] = 0.0;
+ }
+ }
+ }
-
++
+ public double getMetricScore() {
+ return finalMetricScore;
+ }
-
++
+ private Vector<String> output;
+ private double[] initialLambda;
+ private double[] finalLambda;
+ private double finalMetricScore;
+ private HashMap<String, String>[] feat_hash;
+ private HashMap<String, String>[] stats_hash;
+ private int paramDim;
+ private boolean[] isOptimizable;
+ public static int sentNum;
+ public static int miraIter; // MIRA internal iterations
+ public static int oraSelectMode;
+ public static int predSelectMode;
+ public static int batchSize;
+ public static boolean needShuffle;
+ public static boolean needScale;
+ public static double scoreRatio;
+ public static boolean runPercep;
+ public static boolean needAvg;
+ public static boolean usePseudoBleu;
+ public static double featScale = 1.0; // scale the features in order to make the model score
+ // comparable with metric score
+ // updates in each epoch if necessary
+ public static double C; // relaxation coefficient
+ public static double R; // corpus decay(used only when pseudo corpus is used to compute BLEU)
+ public static EvaluationMetric evalMetric;
+ public static double[] normalizationOptions;
+ public static double[][] bleuHistory;
+
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+}
[02/17] incubator-joshua git commit: Fix a number of issues: - Reader
now implements autocloseable - Close various leaks from LineReader -
LineReader no longer implements custom finalize(). Resources should be
explicitly closed when no longer needed. T
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/Constants.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/Constants.java b/src/main/java/org/apache/joshua/util/Constants.java
index 3d4139d..9612a35 100644
--- a/src/main/java/org/apache/joshua/util/Constants.java
+++ b/src/main/java/org/apache/joshua/util/Constants.java
@@ -20,17 +20,17 @@ package org.apache.joshua.util;
/***
* One day, all constants should be moved here (many are in Vocabulary).
- *
+ *
* @author Matt Post post@cs.jhu.edu
*/
public final class Constants {
- public static String defaultNT = "[X]";
+ public static final String defaultNT = "[X]";
public static final String START_SYM = "<s>";
public static final String STOP_SYM = "</s>";
public static final String UNKNOWN_WORD = "<unk>";
-
+
public static final String fieldDelimiter = "\\s\\|{3}\\s";
public static final String spaceSeparator = "\\s+";
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/FileUtility.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/FileUtility.java b/src/main/java/org/apache/joshua/util/FileUtility.java
index a36b07f..0f13e6a 100644
--- a/src/main/java/org/apache/joshua/util/FileUtility.java
+++ b/src/main/java/org/apache/joshua/util/FileUtility.java
@@ -18,38 +18,22 @@
*/
package org.apache.joshua.util;
-import java.io.BufferedReader;
import java.io.BufferedWriter;
-import java.io.Closeable;
import java.io.File;
import java.io.FileDescriptor;
-import java.io.FileInputStream;
import java.io.FileOutputStream;
-import java.io.FileReader;
import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
import java.io.OutputStreamWriter;
-import java.nio.charset.Charset;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Scanner;
+import java.nio.charset.StandardCharsets;
/**
* utility functions for file operations
- *
+ *
* @author Zhifei Li, zhifei.work@gmail.com
* @author wren ng thornton wren@users.sourceforge.net
* @since 28 February 2009
*/
public class FileUtility {
- public static String DEFAULT_ENCODING = "UTF-8";
-
- /*
- * Note: charset name is case-agnostic "UTF-8" is the canonical name "UTF8", "unicode-1-1-utf-8"
- * are aliases Java doesn't distinguish utf8 vs UTF-8 like Perl does
- */
- private static final Charset FILE_ENCODING = Charset.forName(DEFAULT_ENCODING);
/**
* Warning, will truncate/overwrite existing files
@@ -61,122 +45,7 @@ public class FileUtility {
return new BufferedWriter(new OutputStreamWriter(
// TODO: add GZIP
filename.equals("-") ? new FileOutputStream(FileDescriptor.out) : new FileOutputStream(
- filename, false), FILE_ENCODING));
- }
-
- /**
- * Recursively delete the specified file or directory.
- *
- * @param f File or directory to delete
- * @return <code>true</code> if the specified file or directory was deleted, <code>false</code>
- * otherwise
- */
- public static boolean deleteRecursively(File f) {
- if (null != f) {
- if (f.isDirectory())
- for (File child : f.listFiles())
- deleteRecursively(child);
- return f.delete();
- } else {
- return false;
- }
- }
-
- /**
- * Writes data from the integer array to disk as raw bytes, overwriting the old file if present.
- *
- * @param data The integer array to write to disk.
- * @param filename The filename where the data should be written.
- * @throws IOException if there is a problem writing to the output file
- * @return the FileOutputStream on which the bytes were written
- */
- public static FileOutputStream writeBytes(int[] data, String filename) throws IOException {
- FileOutputStream out = new FileOutputStream(filename, false);
- writeBytes(data, out);
- return out;
- }
-
- /**
- * Writes data from the integer array to disk as raw bytes.
- *
- * @param data The integer array to write to disk.
- * @param out The output stream where the data should be written.
- * @throws IOException if there is a problem writing bytes
- */
- public static void writeBytes(int[] data, OutputStream out) throws IOException {
-
- byte[] b = new byte[4];
-
- for (int word : data) {
- b[0] = (byte) ((word >>> 24) & 0xFF);
- b[1] = (byte) ((word >>> 16) & 0xFF);
- b[2] = (byte) ((word >>> 8) & 0xFF);
- b[3] = (byte) ((word >>> 0) & 0xFF);
-
- out.write(b);
- }
- }
-
- public static void copyFile(String srFile, String dtFile) throws IOException {
- try {
- File f1 = new File(srFile);
- File f2 = new File(dtFile);
- copyFile(f1, f2);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public static void copyFile(File srFile, File dtFile) throws IOException {
- try {
-
- InputStream in = new FileInputStream(srFile);
-
- // For Append the file.
- // OutputStream out = new FileOutputStream(f2,true);
-
- // For Overwrite the file.
- OutputStream out = new FileOutputStream(dtFile);
-
- byte[] buf = new byte[1024];
- int len;
- while ((len = in.read(buf)) > 0) {
- out.write(buf, 0, len);
- }
- in.close();
- out.close();
- System.out.println("File copied.");
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- static public boolean deleteFile(String fileName) {
-
- File f = new File(fileName);
-
- // Make sure the file or directory exists and isn't write protected
- if (!f.exists())
- System.out.println("Delete: no such file or directory: " + fileName);
-
- if (!f.canWrite())
- System.out.println("Delete: write protected: " + fileName);
-
- // If it is a directory, make sure it is empty
- if (f.isDirectory()) {
- String[] files = f.list();
- if (files.length > 0)
- System.out.println("Delete: directory not empty: " + fileName);
- }
-
- // Attempt to delete it
- boolean success = f.delete();
-
- if (!success)
- System.out.println("Delete: deletion failed");
-
- return success;
-
+ filename, false), StandardCharsets.UTF_8));
}
/**
@@ -191,128 +60,4 @@ public class FileUtility {
return ".";
}
-
- public static void createFolderIfNotExisting(String folderName) {
- File f = new File(folderName);
- if (!f.isDirectory()) {
- System.out.println(" createFolderIfNotExisting -- Making directory: " + folderName);
- f.mkdirs();
- } else {
- System.out.println(" createFolderIfNotExisting -- Directory: " + folderName
- + " already existed");
- }
- }
-
- public static void closeCloseableIfNotNull(Closeable fileWriter) {
- if (fileWriter != null) {
- try {
- fileWriter.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- }
-
- /**
- * Returns the directory were the program has been started,
- * the base directory you will implicitly get when specifying no
- * full path when e.g. opening a file
- * @return the current 'user.dir'
- */
- public static String getWorkingDirectory() {
- return System.getProperty("user.dir");
- }
-
- /**
- * Method to handle standard IO exceptions. catch (Exception e) {Utility.handleIO_exception(e);}
- * @param e an input {@link java.lang.Exception}
- */
- public static void handleExceptions(Exception e) {
- throw new RuntimeException(e);
- }
-
- /**
- * Convenience method to get a full file as a String
- * @param file the input {@link java.io.File}
- * @return The file as a String. Lines are separated by newline character.
- */
- public static String getFileAsString(File file) {
- String result = "";
- List<String> lines = getLines(file, true);
- for (int i = 0; i < lines.size() - 1; i++) {
- result += lines.get(i) + "\n";
- }
- if (!lines.isEmpty()) {
- result += lines.get(lines.size() - 1);
- }
- return result;
- }
-
- /**
- * This method returns a List of String. Each element of the list corresponds to a line from the
- * input file. The boolean keepDuplicates in the input determines if duplicate lines are allowed
- * in the output LinkedList or not.
- * @param file the input file
- * @param keepDuplicates whether to retain duplicate lines
- * @return a {@link java.util.List} of lines
- */
- static public List<String> getLines(File file, boolean keepDuplicates) {
- LinkedList<String> list = new LinkedList<String>();
- String line = "";
- try {
- BufferedReader InputReader = new BufferedReader(new FileReader(file));
- for (;;) { // this loop writes writes in a Sting each sentence of
- // the file and process it
- int current = InputReader.read();
- if (current == -1 || current == '\n') {
- if (keepDuplicates || !list.contains(line))
- list.add(line);
- line = "";
- if (current == -1)
- break; // EOF
- } else
- line += (char) current;
- }
- InputReader.close();
- } catch (Exception e) {
- handleExceptions(e);
- }
- return list;
- }
-
- /**
- * Returns a Scanner of the inputFile using a specific encoding
- *
- * @param inputFile the file for which to get a {@link java.util.Scanner} object
- * @param encoding the encoding to use within the Scanner
- * @return a {@link java.util.Scanner} object for a given file
- */
- public static Scanner getScanner(File inputFile, String encoding) {
- Scanner scan = null;
- try {
- scan = new Scanner(inputFile, encoding);
- } catch (IOException e) {
- FileUtility.handleExceptions(e);
- }
- return scan;
- }
-
- /**
- * Returns a Scanner of the inputFile using default encoding
- *
- * @param inputFile the file for which to get a {@link java.util.Scanner} object
- * @return a {@link java.util.Scanner} object for a given file
- */
- public static Scanner getScanner(File inputFile) {
- return getScanner(inputFile, DEFAULT_ENCODING);
- }
-
- static public String getFirstLineInFile(File inputFile) {
- Scanner scan = FileUtility.getScanner(inputFile);
- if (!scan.hasNextLine())
- return null;
- String line = scan.nextLine();
- scan.close();
- return line;
- }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/IntegerPair.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/IntegerPair.java b/src/main/java/org/apache/joshua/util/IntegerPair.java
deleted file mode 100644
index bfbfa23..0000000
--- a/src/main/java/org/apache/joshua/util/IntegerPair.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-/**
- * Memory-efficient implementation of an integer tuple.
- *
- * @author Lane Schwartz
- */
-public final class IntegerPair {
-
- public final int first;
- public final int second;
-
- public IntegerPair(final int first, final int second) {
- this.first = first;
- this.second = second;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/ListUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/ListUtil.java b/src/main/java/org/apache/joshua/util/ListUtil.java
index afb5af1..14154e8 100644
--- a/src/main/java/org/apache/joshua/util/ListUtil.java
+++ b/src/main/java/org/apache/joshua/util/ListUtil.java
@@ -22,55 +22,6 @@ import java.util.List;
public class ListUtil {
- /**
- * Static method to generate a list representation for an ArrayList of Strings S1,...,Sn
- *
- * @param list A list of Strings
- * @return A String consisting of the original list of strings concatenated and separated by
- * commas, and enclosed by square brackets i.e. '[S1,S2,...,Sn]'
- */
- public static String stringListString(List<String> list) {
-
- String result = "[";
- for (int i = 0; i < list.size() - 1; i++) {
- result += list.get(i) + ",";
- }
-
- if (list.size() > 0) {
- // get the generated word for the last target position
- result += list.get(list.size() - 1);
- }
-
- result += "]";
-
- return result;
-
- }
-
- public static <E> String objectListString(List<E> list) {
- String result = "[";
- for (int i = 0; i < list.size() - 1; i++) {
- result += list.get(i) + ",";
- }
- if (list.size() > 0) {
- // get the generated word for the last target position
- result += list.get(list.size() - 1);
- }
- result += "]";
- return result;
- }
-
- /**
- * Static method to generate a simple concatenated representation for an ArrayList of Strings
- * S1,...,Sn
- *
- * @param list A list of Strings
- * @return todo
- */
- public static String stringListStringWithoutBrackets(List<String> list) {
- return stringListStringWithoutBracketsWithSpecifiedSeparator(list, " ");
- }
-
public static String stringListStringWithoutBracketsCommaSeparated(List<String> list) {
return stringListStringWithoutBracketsWithSpecifiedSeparator(list, ",");
}
@@ -89,7 +40,5 @@ public class ListUtil {
}
return result;
-
}
-
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/Lists.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/Lists.java b/src/main/java/org/apache/joshua/util/Lists.java
deleted file mode 100644
index d62d1aa..0000000
--- a/src/main/java/org/apache/joshua/util/Lists.java
+++ /dev/null
@@ -1,567 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.util.Iterator;
-import java.util.NoSuchElementException;
-
-/**
- *
- *
- * @author Lane Schwartz
- */
-public class Lists {
-
- // public static void main(String[] args) {
- //
- // int[] list = {100, 200, 300, 400, 500};
- //
- // for (IndexedInt i : eachWithIndex(list)) {
- //
- // System.out.println(i.getIndex() + " " + i.getValue());
- //
- // }
- //
- // Integer[] list2 = new Integer[]{10, 20, 30, 40};
- // for (Indexed<Integer> i : eachWithIndex(list2)) {
- //
- // System.out.println(i.getIndex() + " " + i.getValue());
- //
- // }
- //
- // java.util.List<Integer> list3 = new java.util.ArrayList<Integer>();
- // for (int i : list2) { list3.add(i); }
- //
- // for (Indexed<Integer> i : eachWithIndex(list3)) {
- //
- // System.out.println(i.getIndex() + " " + i.getValue());
- //
- // }
- // }
-
-
- public static Iterable<Integer> upto(final int exclusiveUpperBound) {
- return new Iterable<Integer>() {
- public Iterator<Integer> iterator() {
- return new Iterator<Integer>() {
- int next = 0;
-
- public boolean hasNext() {
- return next < exclusiveUpperBound;
- }
-
- public Integer next() {
- if (!hasNext()) {
- throw new NoSuchElementException();
- }
- int result = next;
- next += 1;
- return result;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static Iterable<IndexedByte> eachWithIndex(final byte[] list) {
-
- return new Iterable<IndexedByte>() {
-
- public Iterator<IndexedByte> iterator() {
- return new Iterator<IndexedByte>() {
-
- int nextIndex = -1;
- IndexedByte indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public IndexedByte next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new IndexedByte(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static Iterable<IndexedShort> eachWithIndex(final short[] list) {
-
- return new Iterable<IndexedShort>() {
-
- public Iterator<IndexedShort> iterator() {
- return new Iterator<IndexedShort>() {
-
- int nextIndex = -1;
- IndexedShort indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public IndexedShort next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new IndexedShort(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static Iterable<IndexedInt> eachWithIndex(final int[] list) {
-
- return new Iterable<IndexedInt>() {
-
- public Iterator<IndexedInt> iterator() {
- return new Iterator<IndexedInt>() {
-
- int nextIndex = -1;
- IndexedInt indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public IndexedInt next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new IndexedInt(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static Iterable<IndexedLong> eachWithIndex(final long[] list) {
-
- return new Iterable<IndexedLong>() {
-
- public Iterator<IndexedLong> iterator() {
- return new Iterator<IndexedLong>() {
-
- int nextIndex = -1;
- IndexedLong indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public IndexedLong next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new IndexedLong(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static Iterable<IndexedFloat> eachWithIndex(final float[] list) {
-
- return new Iterable<IndexedFloat>() {
-
- public Iterator<IndexedFloat> iterator() {
- return new Iterator<IndexedFloat>() {
-
- int nextIndex = -1;
- IndexedFloat indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public IndexedFloat next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new IndexedFloat(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static Iterable<IndexedDouble> eachWithIndex(final double[] list) {
-
- return new Iterable<IndexedDouble>() {
-
- public Iterator<IndexedDouble> iterator() {
- return new Iterator<IndexedDouble>() {
-
- int nextIndex = -1;
- IndexedDouble indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public IndexedDouble next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new IndexedDouble(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static <V> Iterable<Indexed<V>> eachWithIndex(final V[] list) {
- return new Iterable<Indexed<V>>() {
-
- public Iterator<Indexed<V>> iterator() {
- return new Iterator<Indexed<V>>() {
-
- int nextIndex = -1;
- Indexed<V> indexedValue;
-
- public boolean hasNext() {
- return (nextIndex < list.length);
- }
-
- public Indexed<V> next() {
- if (nextIndex >= list.length) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new Indexed<V>(list[nextIndex], nextIndex);
- } else {
- indexedValue.value = list[nextIndex];
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static <V> Iterable<Indexed<V>> eachWithIndex(final Iterator<V> valueIterator) {
- return new Iterable<Indexed<V>>() {
-
- public Iterator<Indexed<V>> iterator() {
- return new Iterator<Indexed<V>>() {
-
- int nextIndex = -1;
- Indexed<V> indexedValue;
-
- public boolean hasNext() {
- return valueIterator.hasNext();
- }
-
- public Indexed<V> next() {
- if (!valueIterator.hasNext()) {
- throw new NoSuchElementException();
- } else if (nextIndex < 0) {
- nextIndex = 0;
- indexedValue = new Indexed<V>(valueIterator.next(), nextIndex);
- } else {
- indexedValue.value = valueIterator.next();
- indexedValue.index = nextIndex;
- }
-
- nextIndex += 1;
- return indexedValue;
- }
-
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
- };
- }
-
- };
- }
-
- public static <V> Iterable<Indexed<V>> eachWithIndex(final Iterable<V> iterable) {
- return eachWithIndex(iterable.iterator());
- }
-
-
- public static class Index {
-
- int index;
-
- Index(int index) {
- this.index = index;
- }
-
- public int getIndex() {
- return this.index;
- }
-
- void setIndex(int index) {
- this.index = index;
- }
- }
-
- public static class IndexedBoolean extends Index {
-
- boolean value;
-
- IndexedBoolean(boolean value, int index) {
- super(index);
- this.value = value;
- }
-
- public boolean getValue() {
- return this.value;
- }
-
- void setValue(boolean value) {
- this.value = value;
- this.index += 1;
- }
- }
-
- public static class IndexedByte extends Index {
-
- byte value;
-
- IndexedByte(byte value, int index) {
- super(index);
- this.value = value;
- }
-
- public byte getValue() {
- return this.value;
- }
-
- void setValue(byte value) {
- this.value = value;
- this.index += 1;
- }
- }
-
- public static class IndexedShort extends Index {
-
- short value;
-
- IndexedShort(short value, int index) {
- super(index);
- this.value = value;
- }
-
- public short getValue() {
- return this.value;
- }
-
- void setValue(short value) {
- this.value = value;
- this.index += 1;
- }
- }
-
- public static class IndexedInt extends Index {
-
- int value;
-
- IndexedInt(int value, int index) {
- super(index);
- this.value = value;
- }
-
- public int getValue() {
- return this.value;
- }
-
- void setValue(int value) {
- this.value = value;
- this.index += 1;
- }
- }
-
- public static class IndexedLong extends Index {
-
- long value;
-
- IndexedLong(long value, int index) {
- super(index);
- this.value = value;
- }
-
- public long getValue() {
- return this.value;
- }
-
- void setValue(long value) {
- this.value = value;
- this.index += 1;
- }
- }
-
- public static class IndexedFloat extends Index {
-
- float value;
-
- IndexedFloat(float value, int index) {
- super(index);
- this.value = value;
- }
-
- public float getValue() {
- return this.value;
- }
-
- void setValue(float value) {
- this.value = value;
- this.index += 1;
- }
- }
-
- public static class IndexedDouble extends Index {
-
- double value;
-
- IndexedDouble(double value, int index) {
- super(index);
- this.value = value;
- }
-
- public double getValue() {
- return this.value;
- }
-
- void setValue(double value) {
- this.value = value;
- this.index += 1;
- }
- }
-
-
- public static class Indexed<V> extends Index {
-
- V value;
-
- Indexed(V value, int index) {
- super(index);
- this.value = value;
- }
-
- public V getValue() {
- return this.value;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/NullIterator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/NullIterator.java b/src/main/java/org/apache/joshua/util/NullIterator.java
deleted file mode 100644
index c6e4b46..0000000
--- a/src/main/java/org/apache/joshua/util/NullIterator.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.util.Iterator;
-import java.util.NoSuchElementException;
-
-
/**
 * A null-object iterator: it is both the {@link Iterable} and the {@link Iterator} over an
 * always-empty sequence.
 *
 * @author wren ng thornton wren@users.sourceforge.net
 * @param <E> nominal element type; no element is ever produced
 */
public class NullIterator<E> implements Iterable<E>, Iterator<E> {

  /**
   * Returns this very object so that an instance can be used directly in an enhanced
   * {@code for} loop. The result is typed as {@link Iterator} because the class implements
   * both interfaces, which would otherwise make some call sites ambiguous.
   */
  @Override
  public Iterator<E> iterator() {
    return this;
  }

  /** @return always {@code false}: there is never a next element */
  @Override
  public boolean hasNext() {
    return false;
  }

  /**
   * Never returns normally.
   *
   * @throws NoSuchElementException on every call, since the sequence is empty
   */
  @Override
  public E next() throws NoSuchElementException {
    throw new NoSuchElementException();
  }

  /**
   * Removal is not supported.
   *
   * @throws UnsupportedOperationException on every call
   */
  @Override
  public void remove() throws UnsupportedOperationException {
    throw new UnsupportedOperationException();
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/QuietFormatter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/QuietFormatter.java b/src/main/java/org/apache/joshua/util/QuietFormatter.java
deleted file mode 100644
index 7220080..0000000
--- a/src/main/java/org/apache/joshua/util/QuietFormatter.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.util.logging.Formatter;
-import java.util.logging.LogRecord;
-
/**
 * A logging {@link Formatter} that writes only the (localized, parameter-substituted) message
 * text of each record followed by {@code "\n"}: no timestamp, level, or source information.
 *
 * @author Lane Schwartz
 */
public class QuietFormatter extends Formatter {

  @Override
  public String format(LogRecord record) {
    String message = formatMessage(record);
    return message + "\n";
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/ReverseOrder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/ReverseOrder.java b/src/main/java/org/apache/joshua/util/ReverseOrder.java
deleted file mode 100644
index 0270036..0000000
--- a/src/main/java/org/apache/joshua/util/ReverseOrder.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.util.Comparator;
-
/**
 * A {@link Comparator} that imposes the reverse of the natural ordering of {@link Comparable}
 * objects.
 * <p>
 * The comparison is computed as {@code obj2.compareTo(obj1)} rather than by negating
 * {@code obj1.compareTo(obj2)}. Negation is wrong when {@code compareTo} returns
 * {@link Integer#MIN_VALUE}, because {@code -Integer.MIN_VALUE} overflows back to
 * {@code Integer.MIN_VALUE} and the order would not be reversed. Swapping the operands is safe
 * because the {@link Comparable} contract requires
 * {@code sgn(x.compareTo(y)) == -sgn(y.compareTo(x))}.
 *
 * @param <K> the element type, compared by its natural ordering
 * @author Chris Callison-Burch
 * @since 2 June 2008
 */
public class ReverseOrder<K extends Comparable<K>> implements Comparator<K> {

  @Override
  public int compare(K obj1, K obj2) {
    return obj2.compareTo(obj1);
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/SampledList.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/SampledList.java b/src/main/java/org/apache/joshua/util/SampledList.java
deleted file mode 100644
index 60b0ef9..0000000
--- a/src/main/java/org/apache/joshua/util/SampledList.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.util.AbstractList;
-import java.util.List;
-
/**
 * A read-only view that samples a backing list at a fixed stride.
 * <p>
 * If the backing list has at most {@code sampleSize} elements, the view exposes all of them.
 * Otherwise it exposes exactly {@code sampleSize} elements, taken every
 * {@code list.size() / sampleSize} positions starting at index 0. The view's size is fixed when
 * it is constructed; element reads go through to the backing list. Mutators are the
 * {@link AbstractList} defaults, which throw {@link UnsupportedOperationException}.
 *
 * @author Lane Schwartz
 * @param <E> element type
 */
public class SampledList<E> extends AbstractList<E> implements List<E> {

  private final List<E> list;
  private final int size;
  private final int stepSize;

  /**
   * Constructs a sampled list backed by a provided list.
   * <p>
   * The size of this list will be no greater than the provided sample size.
   *
   * @param list List from which to sample.
   * @param sampleSize Maximum number of items to include in the new sampled list; {@code 0}
   *     yields an empty view.
   * @throws IllegalArgumentException if {@code sampleSize} is negative
   */
  public SampledList(List<E> list, int sampleSize) {
    if (sampleSize < 0) {
      throw new IllegalArgumentException("sampleSize must be non-negative: " + sampleSize);
    }
    this.list = list;

    int listSize = list.size();

    if (listSize <= sampleSize) {
      this.size = listSize;
      this.stepSize = 1;
    } else if (sampleSize == 0) {
      // A zero-element sample is simply empty; computing a stride would divide by zero.
      this.size = 0;
      this.stepSize = 1;
    } else {
      this.size = sampleSize;
      this.stepSize = listSize / sampleSize;
    }
  }

  /**
   * Returns the {@code index}-th sampled element, i.e. {@code list.get(index * stepSize)}.
   *
   * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, size())}
   */
  @Override
  public E get(int index) {
    // List.get must reject indices outside [0, size()). Without this check, an index >= size
    // would silently read an unsampled element of the (larger) backing list.
    if (index < 0 || index >= size) {
      throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
    }
    return list.get(index * stepSize);
  }

  @Override
  public int size() {
    return size;
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/SocketUtility.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/SocketUtility.java b/src/main/java/org/apache/joshua/util/SocketUtility.java
deleted file mode 100644
index 134fd35..0000000
--- a/src/main/java/org/apache/joshua/util/SocketUtility.java
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util;
-
-import java.io.BufferedReader;
-import java.io.DataInputStream;
-import java.io.IOException;
-import java.io.InputStreamReader;
-import java.io.OutputStreamWriter;
-import java.io.PrintWriter;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.Socket;
-import java.net.SocketAddress;
-
/**
 * Minimal line-oriented TCP client helpers.
 *
 * @author Zhifei Li, zhifei.work@gmail.com
 */
public class SocketUtility {

  /** Connect timeout used by {@code open_connection_client}, in milliseconds (3 seconds). */
  private static final int CONNECT_TIMEOUT_MS = 3000;

  /**
   * Opens a TCP connection to {@code hostname:port} with keep-alive enabled and wraps its
   * streams in a line reader and writer.
   *
   * @param hostname host name or textual IP address of the server
   * @param port server port
   * @return a connected {@link ClientConnection}
   * @throws RuntimeException wrapping the underlying {@link IOException} if the host cannot be
   *     resolved, the connection fails, or it does not complete within 3 seconds; in that case
   *     the socket is closed before the exception propagates
   */
  public static ClientConnection open_connection_client(String hostname, int port) {
    Socket socket = new Socket(); // created unbound so that a connect timeout can be applied
    try {
      SocketAddress sockaddr = new InetSocketAddress(InetAddress.getByName(hostname), port);
      // Blocks for at most CONNECT_TIMEOUT_MS; a timeout raises SocketTimeoutException.
      socket.connect(sockaddr, CONNECT_TIMEOUT_MS);
      socket.setKeepAlive(true);

      ClientConnection res = new ClientConnection();
      res.socket = socket;
      // NOTE(review): no charset is given, so both directions use the platform default
      // encoding; confirm what the server expects before switching to an explicit charset.
      res.in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
      res.out = new PrintWriter(new OutputStreamWriter(socket.getOutputStream()));
      return res;
    } catch (IOException e) {
      // Release the descriptor instead of leaking a half-opened socket.
      try {
        socket.close();
      } catch (IOException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw new RuntimeException(e);
    }
  }

  /** A socket plus a line reader and writer over its streams. */
  public static class ClientConnection {
    Socket socket;
    public BufferedReader in;
    public PrintWriter out;

    /**
     * Sends {@code line_out} as one line and waits for a one-line reply.
     * <p>
     * No read timeout is configured, so this blocks until the server answers. If the server has
     * closed the connection, {@code readLine} reports end of stream and the result is
     * {@code null}. I/O errors are printed and also yield {@code null}.
     *
     * @param line_out the request line
     * @return the reply line, or {@code null} on end of stream or I/O error
     */
    public String exe_request(String line_out) {
      String line_res = null;
      try {
        out.println(line_out);
        out.flush();
        line_res = in.readLine();
      } catch (IOException ioe) {
        ioe.printStackTrace();
      }
      return line_res;
    }

    /** Writes {@code line_out} followed by a line separator and flushes. */
    public void write_line(String line_out) {
      out.println(line_out);
      out.flush();
    }

    /** Writes the decimal text of {@code line_out} followed by a line separator and flushes. */
    public void write_int(int line_out) {
      out.println(line_out);
      out.flush();
    }

    /**
     * Reads one line, blocking until it arrives (no read timeout is configured).
     *
     * @return the line, or {@code null} on end of stream or I/O error (the error is printed)
     */
    public String read_line() {
      String line_res = null;
      try {
        line_res = in.readLine();
      } catch (IOException ioe) {
        ioe.printStackTrace();
      }
      return line_res;
    }

    /**
     * Closes the socket, which also closes its input and output streams. A connection whose
     * socket was never set (for example one constructed directly) is left untouched.
     */
    public void close() {
      if (socket == null) {
        return;
      }
      try {
        socket.close();
      } catch (IOException ioe) {
        ioe.printStackTrace();
      }
    }

    /**
     * Reads eight bytes, least significant first, and reinterprets them as a {@code double}.
     * On an I/O error the stack trace is printed and the bytes read so far are used.
     */
    public static double readDoubleLittleEndian(DataInputStream d_in) {
      long accum = 0;
      try {
        for (int shiftBy = 0; shiftBy < 64; shiftBy += 8) {
          // must cast to long or shift done modulo 32
          accum |= ((long) (d_in.readByte() & 0xff)) << shiftBy;
        }
      } catch (IOException ioe) {
        ioe.printStackTrace();
      }

      // there is no such method as Double.reverseBytes(d), so the bytes are assembled manually
      return Double.longBitsToDouble(accum);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/encoding/Analyzer.java b/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
index ad2910c..d9bab66 100644
--- a/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
+++ b/src/main/java/org/apache/joshua/util/encoding/Analyzer.java
@@ -87,7 +87,7 @@ public class Analyzer {
// If the count is not 0, i.e. there were negative values, we should
// not bucket them with the positive ones. Close out the bucket now.
if (count != 0 && index < buckets.length - 2) {
- buckets[index++] = (float) sum / count;
+ buckets[index++] = sum / count;
count = 0;
sum = 0;
}
@@ -98,15 +98,15 @@ public class Analyzer {
sum += key * value;
// Check if the bucket is full.
if (count >= size && index < buckets.length - 2) {
- buckets[index++] = (float) sum / count;
+ buckets[index++] = sum / count;
count = 0;
sum = 0;
}
last_key = key;
}
if (count > 0 && index < buckets.length - 1)
- buckets[index++] = (float) sum / count;
-
+ buckets[index++] = sum / count;
+
float[] shortened = new float[index];
for (int i = 0; i < shortened.length; ++i)
shortened[i] = buckets[i];
@@ -167,7 +167,7 @@ public class Analyzer {
return PrimitiveFloatEncoder.INT;
return PrimitiveFloatEncoder.FLOAT;
}
-
+
public FloatEncoder inferType(int bits) {
if (isBoolean())
return PrimitiveFloatEncoder.BOOLEAN;
@@ -191,45 +191,46 @@ public class Analyzer {
sb.append(label + "\t" + String.format("%.5f", val) + "\t" + histogram.get(val) + "\n");
return sb.toString();
}
-
+
public static void main(String[] args) throws IOException {
- LineReader reader = new LineReader(args[0]);
- ArrayList<Float> s = new ArrayList<Float>();
-
- System.out.println("Initialized.");
- while (reader.hasNext())
- s.add(Float.parseFloat(reader.next().trim()));
- System.out.println("Data read.");
- int n = s.size();
- byte[] c = new byte[n];
- ByteBuffer b = ByteBuffer.wrap(c);
- Analyzer q = new Analyzer();
-
- q.initialize();
- for (int i = 0; i < n; i++)
- q.add(s.get(i));
- EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
- System.out.println("Quantizer learned.");
-
- for (int i = 0; i < n; i++)
- eb.write(b, s.get(i));
- b.rewind();
- System.out.println("Quantization complete.");
-
- float avg_error = 0;
- float error = 0;
- int count = 0;
- for (int i = -4; i < n - 4; i++) {
- float coded = eb.read(b, i);
- if (s.get(i + 4) != 0) {
- error = Math.abs(s.get(i + 4) - coded);
- avg_error += error;
- count++;
+ try (LineReader reader = new LineReader(args[0]);) {
+ ArrayList<Float> s = new ArrayList<Float>();
+
+ System.out.println("Initialized.");
+ while (reader.hasNext())
+ s.add(Float.parseFloat(reader.next().trim()));
+ System.out.println("Data read.");
+ int n = s.size();
+ byte[] c = new byte[n];
+ ByteBuffer b = ByteBuffer.wrap(c);
+ Analyzer q = new Analyzer();
+
+ q.initialize();
+ for (int i = 0; i < n; i++)
+ q.add(s.get(i));
+ EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
+ System.out.println("Quantizer learned.");
+
+ for (int i = 0; i < n; i++)
+ eb.write(b, s.get(i));
+ b.rewind();
+ System.out.println("Quantization complete.");
+
+ float avg_error = 0;
+ float error = 0;
+ int count = 0;
+ for (int i = -4; i < n - 4; i++) {
+ float coded = eb.read(b, i);
+ if (s.get(i + 4) != 0) {
+ error = Math.abs(s.get(i + 4) - coded);
+ avg_error += error;
+ count++;
+ }
}
- }
- avg_error /= count;
- System.out.println("Evaluation complete.");
+ avg_error /= count;
+ System.out.println("Evaluation complete.");
- System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
+ System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java b/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
index 504859f..d485ea5 100644
--- a/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
+++ b/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
@@ -62,26 +62,27 @@ public class FeatureTypeAnalyzer {
}
public void readConfig(String config_filename) throws IOException {
- LineReader reader = new LineReader(config_filename);
- while (reader.hasNext()) {
- // Clean up line, chop comments off and skip if the result is empty.
- String line = reader.next().trim();
- if (line.indexOf('#') != -1)
- line = line.substring(0, line.indexOf('#'));
- if (line.isEmpty())
- continue;
- String[] fields = line.split("[\\s]+");
-
- if ("encoder".equals(fields[0])) {
- // Adding an encoder to the mix.
- if (fields.length < 3) {
- throw new RuntimeException("Incomplete encoder line in config.");
+ try(LineReader reader = new LineReader(config_filename);) {
+ while (reader.hasNext()) {
+ // Clean up line, chop comments off and skip if the result is empty.
+ String line = reader.next().trim();
+ if (line.indexOf('#') != -1)
+ line = line.substring(0, line.indexOf('#'));
+ if (line.isEmpty())
+ continue;
+ String[] fields = line.split("[\\s]+");
+
+ if ("encoder".equals(fields[0])) {
+ // Adding an encoder to the mix.
+ if (fields.length < 3) {
+ throw new RuntimeException("Incomplete encoder line in config.");
+ }
+ String encoder_key = fields[1];
+ List<Integer> feature_ids = new ArrayList<Integer>();
+ for (int i = 2; i < fields.length; i++)
+ feature_ids.add(Vocabulary.id(fields[i]));
+ addFeatures(encoder_key, feature_ids);
}
- String encoder_key = fields[1];
- ArrayList<Integer> feature_ids = new ArrayList<Integer>();
- for (int i = 2; i < fields.length; i++)
- feature_ids.add(Vocabulary.id(fields[i]));
- addFeatures(encoder_key, feature_ids);
}
}
}
@@ -182,6 +183,7 @@ public class FeatureTypeAnalyzer {
out_stream.close();
}
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (int feature_id : featureToType.keySet()) {
@@ -198,7 +200,7 @@ public class FeatureTypeAnalyzer {
this.labeled = labeled;
}
- class FeatureType {
+ static class FeatureType {
FloatEncoder encoder;
Analyzer analyzer;
int bits;
@@ -236,6 +238,7 @@ public class FeatureTypeAnalyzer {
analyzer.add(value);
}
+ @Override
public boolean equals(Object t) {
if (t != null && t instanceof FeatureType) {
FeatureType that = (FeatureType) t;
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java b/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java
new file mode 100644
index 0000000..42dd236
--- /dev/null
+++ b/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.io;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.function.Predicate;
+import java.util.stream.Stream;
+
/**
 * A {@link StandardCharsets#UTF_8} text file that existed when this object was created.
 * Construction throws a {@link FileNotFoundException} if the underlying {@link Path}, or the
 * {@link String} naming it, does not exist.
 * <p>
 * Only existence is checked (so a directory path is accepted here and fails later, when
 * read). The line-counting methods re-open and re-read the file on every call.
 */
public class ExistingUTF8EncodedTextFile {
  private static final Predicate<String> emptyStringPredicate = String::isEmpty;

  private final Path p;

  public ExistingUTF8EncodedTextFile(String pathStr) throws FileNotFoundException {
    this(Paths.get(pathStr));
  }

  public ExistingUTF8EncodedTextFile(Path p) throws FileNotFoundException {
    this.p = p;
    if (!Files.exists(p))
      throw new FileNotFoundException("Did not find the file at path: " + p.toString());
  }

  /**
   * @return the {@link Path} representing this object
   */
  public Path getPath() {
    return this.p;
  }

  /**
   * @return the number of lines in the file represented by this object
   * @throws IOException on inability to read file (maybe it's not a text file), including
   *     malformed UTF-8 input
   * @throws ArithmeticException if the line count does not fit in an {@code int}
   */
  public int getNumberOfLines() throws IOException {
    return countLines(false);
  }

  /**
   * @return the number of non-empty lines in the file represented by this object
   * @throws IOException on inability to read file (maybe it's not a text file), including
   *     malformed UTF-8 input
   * @throws ArithmeticException if the line count does not fit in an {@code int}
   */
  public int getNumberOfNonEmptyLines() throws IOException {
    return countLines(true);
  }

  /**
   * Streams the file as UTF-8 and counts its lines, optionally ignoring empty ones.
   *
   * @param skipEmpty whether empty lines are excluded from the count
   * @return the line count
   * @throws IOException if the file cannot be opened, read, or decoded as UTF-8
   */
  private int countLines(boolean skipEmpty) throws IOException {
    try (Stream<String> ls = Files.lines(this.p, StandardCharsets.UTF_8)) {
      Stream<String> counted = skipEmpty ? ls.filter(emptyStringPredicate.negate()) : ls;
      // toIntExact fails loudly instead of silently truncating a count above Integer.MAX_VALUE.
      return Math.toIntExact(counted.count());
    } catch (java.io.UncheckedIOException e) {
      // Files.lines reports read/decoding failures (e.g. malformed UTF-8) lazily, wrapped in
      // UncheckedIOException; unwrap so callers see the checked IOException this API declares.
      throw e.getCause();
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/io/IndexedReader.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/io/IndexedReader.java b/src/main/java/org/apache/joshua/util/io/IndexedReader.java
index f357e55..d206544 100644
--- a/src/main/java/org/apache/joshua/util/io/IndexedReader.java
+++ b/src/main/java/org/apache/joshua/util/io/IndexedReader.java
@@ -25,12 +25,11 @@ import java.util.NoSuchElementException;
/**
* Wraps a reader with "line" index information.
- *
+ *
* @author wren ng thornton wren@users.sourceforge.net
* @version $LastChangedDate: 2009-03-26 15:06:57 -0400 (Thu, 26 Mar 2009) $
*/
public class IndexedReader<E> implements Reader<E> {
-
/** A name for the type of elements the reader produces. */
private final String elementName;
@@ -46,7 +45,7 @@ public class IndexedReader<E> implements Reader<E> {
this.reader = reader;
}
- /**
+ /**
* Return the number of elements delivered so far.
* @return integer representing the number of elements delivered so far
*/
@@ -72,7 +71,7 @@ public class IndexedReader<E> implements Reader<E> {
// Reader
// ===============================================================
- /**
+ /**
* Delegated to the underlying reader.
* @return true if the reader is ready
* @throws IOException if there is an error determining readiness
@@ -92,6 +91,7 @@ public class IndexedReader<E> implements Reader<E> {
* however, when we fall out of scope, the underlying reader will too, so its finalizer may be
* called. For correctness, be sure to manually close all readers.
*/
+ @Override
public void close() throws IOException {
try {
this.reader.close();
@@ -102,6 +102,7 @@ public class IndexedReader<E> implements Reader<E> {
/** Delegated to the underlying reader. */
+ @Override
public E readLine() throws IOException {
E line;
try {
@@ -119,6 +120,7 @@ public class IndexedReader<E> implements Reader<E> {
// ===============================================================
/** Return self as an iterator. */
+ @Override
public Iterator<E> iterator() {
return this;
}
@@ -129,12 +131,14 @@ public class IndexedReader<E> implements Reader<E> {
// ===============================================================
/** Delegated to the underlying reader. */
+ @Override
public boolean hasNext() {
return this.reader.hasNext();
}
/** Delegated to the underlying reader. */
+ @Override
public E next() throws NoSuchElementException {
E line = this.reader.next();
// Let exceptions out, we'll wrap any errors a closing time.
@@ -149,6 +153,7 @@ public class IndexedReader<E> implements Reader<E> {
* returns the number of elements delivered to the client, so removing an element from the
* underlying collection does not affect that number.
*/
+ @Override
public void remove() throws UnsupportedOperationException {
this.reader.remove();
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/io/LineReader.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/io/LineReader.java b/src/main/java/org/apache/joshua/util/io/LineReader.java
index 5122994..09e22c2 100644
--- a/src/main/java/org/apache/joshua/util/io/LineReader.java
+++ b/src/main/java/org/apache/joshua/util/io/LineReader.java
@@ -19,13 +19,13 @@
package org.apache.joshua.util.io;
import java.io.BufferedReader;
+import java.io.File;
import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
-import java.io.File;
-import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.zip.GZIPInputStream;
@@ -35,19 +35,13 @@ import org.apache.joshua.decoder.Decoder;
/**
* This class provides an Iterator interface to a BufferedReader. This covers the most common
* use-cases for reading from files without ugly code to check whether we got a line or not.
- *
+ *
* @author wren ng thornton wren@users.sourceforge.net
* @author Matt Post post@cs.jhu.edu
*/
public class LineReader implements Reader<String> {
/*
- * Note: charset name is case-agnostic "UTF-8" is the canonical name "UTF8", "unicode-1-1-utf-8"
- * are aliases Java doesn't distinguish utf8 vs UTF-8 like Perl does
- */
- private static final Charset FILE_ENCODING = Charset.forName("UTF-8");
-
- /*
* The reader and its underlying input stream. We need to keep a hold of the underlying
* input stream so that we can query how many raw bytes it's read (for a generic progress
* meter that works across GZIP'ed and plain text files).
@@ -59,9 +53,9 @@ public class LineReader implements Reader<String> {
private IOException error;
private int lineno = 0;
-
+
private boolean display_progress = false;
-
+
private int progress = 0;
// ===============================================================
@@ -71,17 +65,17 @@ public class LineReader implements Reader<String> {
/**
* Opens a file for iterating line by line. The special "-" filename can be used to specify
* STDIN. GZIP'd files are tested for automatically.
- *
+ *
* @param filename the file to be opened ("-" for STDIN)
* @throws IOException if there is an error reading the input file
*/
public LineReader(String filename) throws IOException {
-
+
display_progress = (Decoder.VERBOSE >= 1);
-
+
progress = 0;
-
- InputStream stream = null;
+
+ InputStream stream = null;
long totalBytes = -1;
if (filename.equals("-")) {
rawStream = null;
@@ -97,11 +91,11 @@ public class LineReader implements Reader<String> {
rawStream.close();
stream = rawStream = new ProgressInputStream(new FileInputStream(filename), totalBytes);
}
- }
-
- this.reader = new BufferedReader(new InputStreamReader(stream, FILE_ENCODING));
+ }
+
+ this.reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8));
}
-
+
public LineReader(String filename, boolean show_progress) throws IOException {
this(filename);
display_progress = (Decoder.VERBOSE >= 1 && show_progress);
@@ -113,19 +107,19 @@ public class LineReader implements Reader<String> {
* @param in an {@link java.io.InputStream} to wrap and iterate over line by line
*/
public LineReader(InputStream in) {
- this.reader = new BufferedReader(new InputStreamReader(in, FILE_ENCODING));
+ this.reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
display_progress = false;
}
-
+
/**
- * Chain to the underlying {@link ProgressInputStream}.
- *
+ * Chain to the underlying {@link ProgressInputStream}.
+ *
* @return an integer from 0..100, indicating how much of the file has been read.
*/
public int progress() {
return rawStream == null ? 0 : rawStream.progress();
}
-
+
/**
* This method will close the file handle, and will raise any exceptions that occured during
* iteration. The method is idempotent, and all calls after the first are no-ops (unless the
@@ -133,6 +127,7 @@ public class LineReader implements Reader<String> {
* object falls out of scope.
* @throws IOException if there is an error closing the file handler
*/
+ @Override
public void close() throws IOException {
this.buffer = null; // Just in case it's a large string
@@ -161,42 +156,13 @@ public class LineReader implements Reader<String> {
}
}
-
- /**
- * We attempt to avoid leaking file descriptors if you fail to call close before the object falls
- * out of scope. However, the language spec makes <b>no guarantees</b> about timeliness of garbage
- * collection. It is a bug to rely on this method to release the resources. Also, the garbage
- * collector will discard any exceptions that have queued up, without notifying the application in
- * any way.
- *
- * Having a finalizer means the JVM can't do "fast allocation" of LineReader objects (or
- * subclasses). This isn't too important due to disk latency, but may be worth noting.
- *
- * @see <a
- * href="http://java2go.blogspot.com/2007/09/javaone-2007-performance-tips-2-finish.html">Performance
- * Tips</a>
- * @see <a
- * href="http://www.javaworld.com/javaworld/jw-06-1998/jw-06-techniques.html?page=1">Techniques</a>
- */
- protected void finalize() throws Throwable {
- try {
- this.close();
- } catch (IOException e) {
- // Do nothing. The GC will discard the exception
- // anyways, but it may cause us to linger on the heap.
- } finally {
- super.finalize();
- }
- }
-
-
-
// ===============================================================
// Reader
// ===============================================================
// Copied from interface documentation.
/** Determine if the reader is ready to read a line. */
+ @Override
public boolean ready() throws IOException {
return this.reader.ready();
}
@@ -206,6 +172,7 @@ public class LineReader implements Reader<String> {
* This method is like next() except that it throws the IOException directly. If there are no
* lines to be read then null is returned.
*/
+ @Override
public String readLine() throws IOException {
if (this.hasNext()) {
String line = this.buffer;
@@ -228,6 +195,7 @@ public class LineReader implements Reader<String> {
// ===============================================================
/** Return self as an iterator. */
+ @Override
public Iterator<String> iterator() {
return this;
}
@@ -243,6 +211,7 @@ public class LineReader implements Reader<String> {
* <code>true</code> if <code>next</code> would return an element rather than throwing an
* exception.)
*/
+ @Override
public boolean hasNext() {
if (null != this.buffer) {
return true;
@@ -269,12 +238,13 @@ public class LineReader implements Reader<String> {
* The actual IOException encountered will be thrown later, when the LineReader is closed. Also if
* there is no line to be read then NoSuchElementException is thrown.
*/
+ @Override
public String next() throws NoSuchElementException {
if (this.hasNext()) {
if (display_progress) {
int newProgress = (reader != null) ? progress() : 100;
// System.err.println(String.format("OLD %d NEW %d", progress, newProgress));
-
+
if (newProgress > progress) {
for (int i = progress + 1; i <= newProgress; i++)
if (i == 97) {
@@ -297,7 +267,7 @@ public class LineReader implements Reader<String> {
progress = newProgress;
}
}
-
+
String line = this.buffer;
this.lineno++;
this.buffer = null;
@@ -306,39 +276,19 @@ public class LineReader implements Reader<String> {
throw new NoSuchElementException();
}
}
-
+
/* Get the line number of the last line that was returned */
public int lineno() {
return this.lineno;
}
/** Unsupported. */
+ @Override
public void remove() throws UnsupportedOperationException {
throw new UnsupportedOperationException();
}
-
/**
- * Iterates over all lines, ignoring their contents, and returns the count of lines. If some lines
- * have already been read, this will return the count of remaining lines. Because no lines will
- * remain after calling this method, we implicitly call close.
- *
- * @return the number of lines read
- * @throws IOException if there is an error reading lines
- */
- public int countLines() throws IOException {
- int lines = 0;
-
- while (this.hasNext()) {
- this.next();
- lines++;
- }
- this.close();
-
- return lines;
- }
-
- /**
* Example usage code.
* @param args an input file
*/
@@ -348,19 +298,10 @@ public class LineReader implements Reader<String> {
System.exit(1);
}
- try {
-
- LineReader in = new LineReader(args[0]);
- try {
- for (String line : in) {
-
- System.out.println(line);
-
- }
- } finally {
- in.close();
+ try (LineReader in = new LineReader(args[0]);) {
+ for (String line : in) {
+ System.out.println(line);
}
-
} catch (IOException e) {
e.printStackTrace();
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/io/NullReader.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/io/NullReader.java b/src/main/java/org/apache/joshua/util/io/NullReader.java
deleted file mode 100644
index f833f00..0000000
--- a/src/main/java/org/apache/joshua/util/io/NullReader.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.joshua.util.io;
-
-import java.io.IOException;
-
-import org.apache.joshua.util.NullIterator;
-
-
-/**
- * This class provides a null-object Reader. This is primarily useful for when you may or may not
- * have a {@link Reader}, and you don't want to check for null all the time. All operations are
- * no-ops.
- *
- * @author wren ng thornton wren@users.sourceforge.net
- * @version $LastChangedDate: 2009-03-26 15:06:57 -0400 (Thu, 26 Mar 2009) $
- */
-public class NullReader<E> extends NullIterator<E> implements Reader<E> {
-
- // ===============================================================
- // Constructors and destructors
- // ===============================================================
-
- // TODO: use static factory method and singleton?
- public NullReader() {}
-
- /** A no-op. */
- public void close() throws IOException {}
-
-
- // ===============================================================
- // Reader
- // ===============================================================
-
- /**
- * Always returns true. Is this correct? What are the semantics of ready()? We're always capable
- * of delivering nothing, but we're never capable of delivering anything...
- */
- public boolean ready() {
- return true;
- }
-
- /** Always returns null. */
- public E readLine() throws IOException {
- return null;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/io/Reader.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/io/Reader.java b/src/main/java/org/apache/joshua/util/io/Reader.java
index cab6d74..e3a150e 100644
--- a/src/main/java/org/apache/joshua/util/io/Reader.java
+++ b/src/main/java/org/apache/joshua/util/io/Reader.java
@@ -23,26 +23,27 @@ import java.util.Iterator;
/**
* Common interface for Reader type objects.
- *
+ *
* @author wren ng thornton wren@users.sourceforge.net
* @version $LastChangedDate: 2009-03-26 15:06:57 -0400 (Thu, 26 Mar 2009) $
*/
-public interface Reader<E> extends Iterable<E>, Iterator<E> {
+public interface Reader<E> extends Iterable<E>, Iterator<E>, AutoCloseable {
- /**
+ /**
* Close the reader, freeing all resources.
* @throws IOException if there is an error closing the reader instance
*/
+ @Override
void close() throws IOException;
- /**
+ /**
* Determine if the reader is ready to read a line.
* @return true if it is ready
* @throws IOException if there is an error whilst determining if the reader if ready
*/
boolean ready() throws IOException;
- /**
+ /**
* Read a "line" and return an object representing it.
* @return an object representing a single line
* @throws IOException if there is an error reading lines
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/quantization/Quantizer.java b/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
index 33a4e9a..ab291be 100644
--- a/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
+++ b/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
@@ -17,29 +17,27 @@
* under the License.
*/
package org.apache.joshua.util.quantization;
-
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-public interface Quantizer {
-
- public float read(ByteBuffer stream, int position);
-
- public void write(ByteBuffer stream, float value);
-
- public void initialize();
-
- public void add(float key);
-
- public void finalize();
-
- public String getKey();
-
- public void writeState(DataOutputStream out) throws IOException;
-
- public void readState(DataInputStream in) throws IOException;
-
- public int size();
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+public interface Quantizer {
+
+ public float read(ByteBuffer stream, int position);
+
+ public void write(ByteBuffer stream, float value);
+
+ public void initialize();
+
+ public void add(float key);
+
+ public String getKey();
+
+ public void writeState(DataOutputStream out) throws IOException;
+
+ public void readState(DataInputStream in) throws IOException;
+
+ public int size();
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java b/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
index f4765f9..39aef36 100644
--- a/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
+++ b/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
@@ -18,102 +18,97 @@
*/
package org.apache.joshua.util.quantization;
-import java.io.BufferedInputStream;
-import java.io.BufferedOutputStream;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
-import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.corpus.Vocabulary;
-public class QuantizerConfiguration {
+public class QuantizerConfiguration {
- private static final Quantizer DEFAULT;
+ private static final Quantizer DEFAULT;
- private ArrayList<Quantizer> quantizers;
- private Map<Integer, Integer> quantizerByFeatureId;
+ private ArrayList<Quantizer> quantizers;
+ private Map<Integer, Integer> quantizerByFeatureId;
- static {
- DEFAULT = new BooleanQuantizer();
- }
+ static {
+ DEFAULT = new BooleanQuantizer();
+ }
- public QuantizerConfiguration() {
- quantizers = new ArrayList<Quantizer>();
- quantizerByFeatureId = new HashMap<Integer, Integer>();
- }
+ public QuantizerConfiguration() {
+ quantizers = new ArrayList<Quantizer>();
+ quantizerByFeatureId = new HashMap<Integer, Integer>();
+ }
- public void add(String quantizer_key, List<Integer> feature_ids) {
- Quantizer q = QuantizerFactory.get(quantizer_key);
- quantizers.add(q);
- int index = quantizers.size() - 1;
- for (int feature_id : feature_ids)
- quantizerByFeatureId.put(feature_id, index);
- }
+ public void add(String quantizer_key, List<Integer> feature_ids) {
+ Quantizer q = QuantizerFactory.get(quantizer_key);
+ quantizers.add(q);
+ int index = quantizers.size() - 1;
+ for (int feature_id : feature_ids)
+ quantizerByFeatureId.put(feature_id, index);
+ }
- public void initialize() {
- for (Quantizer q : quantizers)
- q.initialize();
- }
+ public void initialize() {
+ for (Quantizer q : quantizers)
+ q.initialize();
+ }
- public void finalize() {
- for (Quantizer q : quantizers)
- q.finalize();
- }
+ public final Quantizer get(int feature_id) {
+ Integer index = quantizerByFeatureId.get(feature_id);
+ return (index != null ? quantizers.get(index) : DEFAULT);
+ }
- public final Quantizer get(int feature_id) {
- Integer index = quantizerByFeatureId.get(feature_id);
- return (index != null ? quantizers.get(index) : DEFAULT);
- }
+ public void read(String file_name) throws IOException {
+ quantizers.clear();
+ quantizerByFeatureId.clear();
- public void read(String file_name) throws IOException {
- quantizers.clear();
- quantizerByFeatureId.clear();
+ File quantizer_file = new File(file_name);
+ DataInputStream in_stream =
+ new DataInputStream(new BufferedInputStream(new FileInputStream(quantizer_file)));
+ int num_quantizers = in_stream.readInt();
+ quantizers.ensureCapacity(num_quantizers);
+ for (int i = 0; i < num_quantizers; i++) {
+ String key = in_stream.readUTF();
+ Quantizer q = QuantizerFactory.get(key);
+ q.readState(in_stream);
+ quantizers.add(q);
+ }
+ int num_mappings = in_stream.readInt();
+ for (int i = 0; i < num_mappings; i++) {
+ String feature_name = in_stream.readUTF();
+ int feature_id = Vocabulary.id(feature_name);
+ int quantizer_index = in_stream.readInt();
+ if (quantizer_index >= num_quantizers) {
+ throw new RuntimeException("Error deserializing QuanitzerConfig. " + "Feature "
+ + feature_name + " referring to quantizer " + quantizer_index + " when only "
+ + num_quantizers + " known.");
+ }
+ this.quantizerByFeatureId.put(feature_id, quantizer_index);
+ }
+ in_stream.close();
+ }
- File quantizer_file = new File(file_name);
- DataInputStream in_stream =
- new DataInputStream(new BufferedInputStream(new FileInputStream(quantizer_file)));
- int num_quantizers = in_stream.readInt();
- quantizers.ensureCapacity(num_quantizers);
- for (int i = 0; i < num_quantizers; i++) {
- String key = in_stream.readUTF();
- Quantizer q = QuantizerFactory.get(key);
- q.readState(in_stream);
- quantizers.add(q);
- }
- int num_mappings = in_stream.readInt();
- for (int i = 0; i < num_mappings; i++) {
- String feature_name = in_stream.readUTF();
- int feature_id = Vocabulary.id(feature_name);
- int quantizer_index = in_stream.readInt();
- if (quantizer_index >= num_quantizers) {
- throw new RuntimeException("Error deserializing QuanitzerConfig. " + "Feature "
- + feature_name + " referring to quantizer " + quantizer_index + " when only "
- + num_quantizers + " known.");
- }
- this.quantizerByFeatureId.put(feature_id, quantizer_index);
- }
- in_stream.close();
- }
-
- public void write(String file_name) throws IOException {
- File vocab_file = new File(file_name);
- DataOutputStream out_stream =
- new DataOutputStream(new BufferedOutputStream(new FileOutputStream(vocab_file)));
- out_stream.writeInt(quantizers.size());
- for (int index = 0; index < quantizers.size(); index++)
- quantizers.get(index).writeState(out_stream);
- out_stream.writeInt(quantizerByFeatureId.size());
- for (int feature_id : quantizerByFeatureId.keySet()) {
- out_stream.writeUTF(Vocabulary.word(feature_id));
- out_stream.writeInt(quantizerByFeatureId.get(feature_id));
- }
- out_stream.close();
- }
+ public void write(String file_name) throws IOException {
+ File vocab_file = new File(file_name);
+ DataOutputStream out_stream =
+ new DataOutputStream(new BufferedOutputStream(new FileOutputStream(vocab_file)));
+ out_stream.writeInt(quantizers.size());
+ for (int index = 0; index < quantizers.size(); index++)
+ quantizers.get(index).writeState(out_stream);
+ out_stream.writeInt(quantizerByFeatureId.size());
+ for (int feature_id : quantizerByFeatureId.keySet()) {
+ out_stream.writeUTF(Vocabulary.word(feature_id));
+ out_stream.writeInt(quantizerByFeatureId.get(feature_id));
+ }
+ out_stream.close();
+ }
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/840eb4ce/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java b/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
index e81e945..a241cdf 100644
--- a/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
+++ b/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
@@ -18,21 +18,23 @@
*/
package org.apache.joshua.util.quantization;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
-abstract class StatelessQuantizer implements Quantizer {
+abstract class StatelessQuantizer implements Quantizer {
- public void initialize() {}
+ @Override
+ public void initialize() {}
- public void add(float key) {}
+ @Override
+ public void add(float key) {}
- public void finalize() {}
+ @Override
+ public void writeState(DataOutputStream out) throws IOException {
+ out.writeUTF(getKey());
+ }
- public void writeState(DataOutputStream out) throws IOException {
- out.writeUTF(getKey());
- }
-
- public void readState(DataInputStream in) throws IOException {}
+ @Override
+ public void readState(DataInputStream in) throws IOException {}
}
\ No newline at end of file
[06/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/zmert/MertCore.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/zmert/MertCore.java
index 4110c97,0000000..abed07a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/zmert/MertCore.java
+++ b/joshua-core/src/main/java/org/apache/joshua/zmert/MertCore.java
@@@ -1,3191 -1,0 +1,3048 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.zmert;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Scanner;
+import java.util.TreeSet;
+import java.util.Vector;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Semaphore;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.metrics.EvaluationMetric;
+import org.apache.joshua.util.StreamGobbler;
++import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This code was originally written by Omar Zaidan. In September of 2012, it was augmented to support
+ * a sparse feature implementation.
- *
++ *
+ * @author Omar Zaidan
+ */
+
+public class MertCore {
+
+ private static final Logger LOG = LoggerFactory.getLogger(MertCore.class);
+
+ private final JoshuaConfiguration joshuaConfiguration;
+ private TreeSet<Integer>[] indicesOfInterest_all;
+
+ private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
- private final Runtime myRuntime = Runtime.getRuntime();
+
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+ private final static double epsilon = 1.0 / 1000000;
+
+ private int verbosity; // anything of priority <= verbosity will be printed
+ // (lower value for priority means more important)
+
+ private Random randGen;
+ private int generatedRands;
+
+ private int numSentences;
+ // number of sentences in the dev set
+ // (aka the "MERT training" set)
+
+ private int numDocuments;
+ // number of documents in the dev set
+ // this should be 1, unless doing doc-level optimization
+
+ private int[] docOfSentence;
+ // docOfSentence[i] stores which document contains the i'th sentence.
+ // docOfSentence is 0-indexed, as are the documents (i.e. first doc is indexed 0)
+
+ private int[] docSubsetInfo;
+ // stores information regarding which subset of the documents are evaluated
+ // [0]: method (0-6)
+ // [1]: first (1-indexed)
+ // [2]: last (1-indexed)
+ // [3]: size
+ // [4]: center
+ // [5]: arg1
+ // [6]: arg2
+ // [1-6] are 0 for method 0, [6] is 0 for methods 1-4 as well
+ // only [1] and [2] are needed for optimization. The rest are only needed for an output message.
+
+ private int refsPerSen;
+ // number of reference translations per sentence
+
+ private int textNormMethod;
+ // 0: no normalization, 1: "NIST-style" tokenization, and also rejoin 'm, 're, *'s, 've, 'll, 'd,
+ // and n't,
+ // 2: apply 1 and also rejoin dashes between letters, 3: apply 1 and also drop non-ASCII
+ // characters
+ // 4: apply 1+2+3
+
+ private int numParams;
+ // number of features for the log-linear model
+
+ private double[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ /* *********************************************************** */
+ /* NOTE: indexing starts at 1 in the following few arrays: */
+ /* *********************************************************** */
+
+ private String[] paramNames;
+ // feature names, needed to read/create config file
+
+ private double[] lambda;
+ // the current weight vector. NOTE: indexing starts at 1.
+
+ private boolean[] isOptimizable;
+ // isOptimizable[c] = true iff lambda[c] should be optimized
+
+ private double[] minThValue;
+ private double[] maxThValue;
+ // when investigating thresholds along the lambda[c] dimension, only values
+ // in the [minThValue[c],maxThValue[c]] range will be considered.
+ // (*) minThValue and maxThValue can be real values as well as -Infinity and +Infinity
+ // (coded as -Inf and +Inf, respectively, in an input file)
+
+ private double[] minRandValue;
+ private double[] maxRandValue;
+ // when choosing a random value for the lambda[c] parameter, it will be
+ // chosen from the [minRandValue[c],maxRandValue[c]] range.
+ // (*) minRandValue and maxRandValue must be real values, but not -Inf or +Inf
+
+ private int damianos_method;
+ private double damianos_param;
+ private double damianos_mult;
+
+ private double[] defaultLambda;
+ // "default" parameter values; simply the values read in the parameter file
+
+ /* *********************************************************** */
+ /* *********************************************************** */
+
+ private Decoder myDecoder;
+ // COMMENT OUT if decoder is not Joshua
+
+ private String decoderCommand;
+ // the command that runs the decoder; read from decoderCommandFileName
+
+ private int decVerbosity;
+ // verbosity level for decoder output. If 0, decoder output is ignored.
+ // If 1, decoder output is printed.
+
+ private int validDecoderExitValue;
+ // return value from running the decoder command that indicates success
+
+ private int numOptThreads;
+ // number of threads to run things in parallel
+
+ private int saveInterFiles;
+ // 0: nothing, 1: only configs, 2: only n-bests, 3: both configs and n-bests
+
+ private int compressFiles;
+ // should Z-MERT gzip the large files? If 0, no compression takes place.
+ // If 1, compression is performed on: decoder output files, temp sents files,
+ // and temp feats files.
+
+ private int sizeOfNBest;
+ // size of N-best list generated by decoder at each iteration
+ // (aka simply N, but N is a bad variable name)
+
+ private long seed;
+ // seed used to create random number generators
+
+ private boolean randInit;
+ // if true, parameters are initialized randomly. If false, parameters
+ // are initialized using values from parameter file.
+
+ private int initsPerIt;
+ // number of intermediate initial points per iteration
+
+ private int maxMERTIterations, minMERTIterations, prevMERTIterations;
+ // max: maximum number of MERT iterations
+ // min: minimum number of MERT iterations before an early MERT exit
+ // prev: number of previous MERT iterations from which to consider candidates (in addition to
+ // the candidates from the current iteration)
+
+ private double stopSigValue;
+ // early MERT exit if no weight changes by more than stopSigValue
+ // (but see minMERTIterations above and stopMinIts below)
+
+ private int stopMinIts;
+ // some early stopping criterion must be satisfied in stopMinIts *consecutive* iterations
+ // before an early exit (but see minMERTIterations above)
+
+ private boolean oneModificationPerIteration;
+ // if true, each MERT iteration performs at most one parameter modification.
+ // If false, a new MERT iteration starts (i.e. a new N-best list is
+ // generated) only after the previous iteration reaches a local maximum.
+
+ private String metricName;
+ // name of evaluation metric optimized by MERT
+
+ private String metricName_display;
+ // name of evaluation metric optimized by MERT, possibly with "doc-level " prefixed
+
+ private String[] metricOptions;
+ // options for the evaluation metric (e.g. for BLEU, maxGramLength and effLengthMethod)
+
+ private EvaluationMetric evalMetric;
+ // the evaluation metric used by MERT
+
+ private int suffStatsCount;
+ // number of sufficient statistics for the evaluation metric
+
+ private String tmpDirPrefix;
+ // prefix for the ZMERT.temp.* files
+
+ private boolean passIterationToDecoder;
+ // should the iteration number be passed as an argument to decoderCommandFileName?
+ // If 1, iteration number is passed. If 0, launch with no arguments.
+
+ private String dirPrefix; // where are all these files located?
+ private String paramsFileName, docInfoFileName, finalLambdaFileName;
+ private String sourceFileName, refFileName, decoderOutFileName;
+ private String decoderConfigFileName, decoderCommandFileName;
+ private String fakeFileNameTemplate, fakeFileNamePrefix, fakeFileNameSuffix;
+
+ // e.g. output.it[1-x].someOldRun would be specified as:
+ // output.it?.someOldRun
+ // and we'd have prefix = "output.it" and suffix = ".sameOldRun"
+
+ // private int useDisk;
+
- public MertCore(JoshuaConfiguration joshuaConfiguration)
++ public MertCore(JoshuaConfiguration joshuaConfiguration)
+ {
+ this.joshuaConfiguration = joshuaConfiguration;
+ }
+
- public MertCore(String[] args, JoshuaConfiguration joshuaConfiguration) {
++ public MertCore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(args);
+ initialize(0);
+ }
+
- public MertCore(String configFileName,JoshuaConfiguration joshuaConfiguration) {
++ public MertCore(String configFileName,JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(cfgFileToArgsArray(configFileName));
+ initialize(0);
+ }
+
- private void initialize(int randsToSkip) {
++ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
+ println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
+
+ randGen = new Random(seed);
+ for (int r = 1; r <= randsToSkip; ++r) {
+ randGen.nextDouble();
+ }
+ generatedRands = randsToSkip;
+
+ if (randsToSkip == 0) {
+ println("----------------------------------------------------", 1);
+ println("Initializing...", 1);
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ println("Random number generator initialized using seed: " + seed, 1);
+ println("", 1);
+ }
+
+ if (refsPerSen > 1) {
+ String refFile = refFileName + "0";
+ if (! new File(refFile).exists())
+ refFile = refFileName + ".0";
+ if (! new File(refFile).exists()) {
- throw new RuntimeException(String.format("* FATAL: can't find first reference file '%s{0,.0}'", refFileName));
++ throw new IOException(String.format("* FATAL: can't find first reference file '%s{0,.0}'", refFileName));
+ }
+
- numSentences = countLines(refFile);
++ numSentences = new ExistingUTF8EncodedTextFile(refFile).getNumberOfLines();
+ } else {
- numSentences = countLines(refFileName);
++ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines();
+ }
+
+ processDocInfo();
+ // sets numDocuments and docOfSentence[]
+
+ if (numDocuments > 1) metricName_display = "doc-level " + metricName;
+
+ set_docSubsetInfo(docSubsetInfo);
+
+
+
- numParams = countNonEmptyLines(paramsFileName) - 1;
++ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
+ // the parameter file contains one line per parameter
+ // and one line for the normalization method
+
+
+ paramNames = new String[1 + numParams];
+ lambda = new double[1 + numParams]; // indexing starts at 1 in these arrays
+ isOptimizable = new boolean[1 + numParams];
+ minThValue = new double[1 + numParams];
+ maxThValue = new double[1 + numParams];
+ minRandValue = new double[1 + numParams];
+ maxRandValue = new double[1 + numParams];
+ // precision = new double[1+numParams];
+ defaultLambda = new double[1 + numParams];
+ normalizationOptions = new double[3];
+
+ try {
+ // read parameter names
+ BufferedReader inFile_names = new BufferedReader(new FileReader(paramsFileName));
+
+ for (int c = 1; c <= numParams; ++c) {
+ String line = "";
+ while (line != null && line.length() == 0) { // skip empty lines
+ line = inFile_names.readLine();
+ }
+ String paramName = (line.substring(0, line.indexOf("|||"))).trim();
+ paramNames[c] = paramName;
+ }
+
+ inFile_names.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ processParamFile();
+ // sets the arrays declared just above
+
+ // SentenceInfo.createV(); // uncomment ONLY IF using vocabulary implementation of SentenceInfo
+
+
+ String[][] refSentences = new String[numSentences][refsPerSen];
+
+ try {
+
+ // read in reference sentences
+ BufferedReader reference_readers[] = new BufferedReader[refsPerSen];
+ if (refsPerSen == 1) {
+ reference_readers[0] = new BufferedReader(new InputStreamReader(new FileInputStream(new File(refFileName)), "utf8"));
+ } else {
+ for (int i = 0; i < refsPerSen; i++) {
+ String refFile = refFileName + i;
+ if (! new File(refFile).exists())
+ refFile = refFileName + "." + i;
+ if (! new File(refFile).exists()) {
+ throw new RuntimeException(String.format("* FATAL: can't find reference file '%s'", refFile));
+ }
+
+ reference_readers[i] = new BufferedReader(new InputStreamReader(new FileInputStream(new File(refFile)), "utf8"));
+ }
+ }
-
++
+ for (int i = 0; i < numSentences; ++i) {
+ for (int r = 0; r < refsPerSen; ++r) {
+ // read the rth reference translation for the ith sentence
+ refSentences[i][r] = normalize(reference_readers[r].readLine(), textNormMethod);
+ }
+ }
+
+ // close all the reference files
- for (int i = 0; i < refsPerSen; i++)
++ for (int i = 0; i < refsPerSen; i++)
+ reference_readers[i].close();
+
+ // read in decoder command, if any
+ decoderCommand = null;
+ if (decoderCommandFileName != null) {
+ if (fileExists(decoderCommandFileName)) {
+ BufferedReader inFile_comm = new BufferedReader(new FileReader(decoderCommandFileName));
+ decoderCommand = inFile_comm.readLine();
+ inFile_comm.close();
+ }
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+
+ // set static data members for the EvaluationMetric class
+ EvaluationMetric.set_numSentences(numSentences);
+ EvaluationMetric.set_numDocuments(numDocuments);
+ EvaluationMetric.set_refsPerSen(refsPerSen);
+ EvaluationMetric.set_refSentences(refSentences);
+ EvaluationMetric.set_tmpDirPrefix(tmpDirPrefix);
+
+ evalMetric = EvaluationMetric.getMetric(metricName, metricOptions);
+
+ suffStatsCount = evalMetric.get_suffStatsCount();
+
+ // set static data members for the IntermediateOptimizer class
+ IntermediateOptimizer.set_MERTparams(numSentences, numDocuments, docOfSentence, docSubsetInfo,
+ numParams, normalizationOptions, isOptimizable, minThValue, maxThValue,
+ oneModificationPerIteration, evalMetric, tmpDirPrefix, verbosity);
+
+
+
+ if (randsToSkip == 0) { // i.e. first iteration
+ println("Number of sentences: " + numSentences, 1);
+ println("Number of documents: " + numDocuments, 1);
+ println("Optimizing " + metricName_display, 1);
+
+ print("docSubsetInfo: {", 1);
+ for (int f = 0; f < 6; ++f)
+ print(docSubsetInfo[f] + ", ", 1);
+ println(docSubsetInfo[6] + "}", 1);
+
+ println("Number of features: " + numParams, 1);
+ print("Feature names: {", 1);
+ for (int c = 1; c <= numParams; ++c) {
+ print("\"" + paramNames[c] + "\"", 1);
+ if (c < numParams) print(",", 1);
+ }
+ println("}", 1);
+ println("", 1);
+
+ println("c Default value\tOptimizable?\tCrit. val. range\tRand. val. range", 1);
+
+ for (int c = 1; c <= numParams; ++c) {
+ print(c + " " + f4.format(lambda[c]) + "\t\t", 1);
+ if (!isOptimizable[c]) {
+ println(" No", 1);
+ } else {
+ print(" Yes\t\t", 1);
+ // print("[" + minThValue[c] + "," + maxThValue[c] + "] @ " + precision[c] +
+ // " precision",1);
+ print(" [" + minThValue[c] + "," + maxThValue[c] + "]", 1);
+ print("\t\t", 1);
+ print(" [" + minRandValue[c] + "," + maxRandValue[c] + "]", 1);
+ println("", 1);
+ }
+ }
+
+ println("", 1);
+ print("Weight vector normalization method: ", 1);
+ if (normalizationOptions[0] == 0) {
+ println("none.", 1);
+ } else if (normalizationOptions[0] == 1) {
+ println("weights will be scaled so that the \"" + paramNames[(int) normalizationOptions[1]]
+ + "\" weight has an absolute value of " + normalizationOptions[2] + ".", 1);
+ } else if (normalizationOptions[0] == 2) {
+ println("weights will be scaled so that the maximum absolute value is "
+ + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 3) {
+ println("weights will be scaled so that the minimum absolute value is "
+ + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 4) {
+ println("weights will be scaled so that the L-" + normalizationOptions[1] + " norm is "
+ + normalizationOptions[2] + ".", 1);
+ }
+
+ println("", 1);
+
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ // rename original config file so it doesn't get overwritten
+ // (original name will be restored in finish())
+ renameFile(decoderConfigFileName, decoderConfigFileName + ".ZMERT.orig");
+
+ } // if (randsToSkip == 0)
+
+
+ if (decoderCommand == null && fakeFileNameTemplate == null) {
+ println("Loading Joshua decoder...", 1);
+ myDecoder = new Decoder(joshuaConfiguration);
+ println("...finished loading @ " + (new Date()), 1);
+ println("");
+ } else {
+ myDecoder = null;
+ }
+
+
+
+ @SuppressWarnings("unchecked")
+ TreeSet<Integer>[] temp_TSA = new TreeSet[numSentences];
+ indicesOfInterest_all = temp_TSA;
+
+ for (int i = 0; i < numSentences; ++i) {
+ indicesOfInterest_all[i] = new TreeSet<Integer>();
+ }
+
+
+ } // void initialize(...)
+
+ /**
+  * Runs Z-MERT using the iteration bounds stored on this instance
+  * ({@code minMERTIterations}, {@code maxMERTIterations}, {@code prevMERTIterations}).
+  */
+ public void run_MERT() {
+   final int minIterations = minMERTIterations;
+   final int maxIterations = maxMERTIterations;
+   final int previousIterations = prevMERTIterations;
+   run_MERT(minIterations, maxIterations, previousIterations);
+ }
+
+ /**
+  * Drives the Z-MERT optimization loop.
+  *
+  * <p>If {@code randInit} is set, {@code lambda} is first replaced by {@code randomLambda()}.
+  * The method then calls {@code run_single_iteration} with iteration numbers 1, 2, 3, ... until
+  * that call returns {@code null} or reports (element 2 of its result equal to 1) that the
+  * iteration just run was the last one. Afterwards it prints the final weights and score, warns
+  * about any weight outside its critical-value range, and deletes the per-iteration
+  * {@code temp.sents/feats/stats} files for iterations 1..maxIts.
+  *
+  * @param minIts forwarded to run_single_iteration (minimum number of iterations before an
+  *     early stop is allowed)
+  * @param maxIts forwarded to run_single_iteration; also the upper bound of the temp-file cleanup
+  * @param prevIts forwarded to run_single_iteration (how many earlier iterations' candidates are
+  *     re-read)
+  */
+ public void run_MERT(int minIts, int maxIts, int prevIts) {
+   final String rule = "----------------------------------------------------";
+
+   println(rule, 1);
+   println("Z-MERT run started @ " + (new Date()), 1);
+   println(rule, 1);
+   println("", 1);
+
+   if (randInit) {
+     println("Initializing lambda[] randomly.", 1);
+     // each optimizable weight is drawn from its random-value range
+     lambda = randomLambda();
+   }
+
+   println("Initial lambda[]: " + lambdaToString(lambda), 1);
+   println("", 1);
+
+   double bestScore = evalMetric.worstPossibleScore();
+
+   // Per-sentence highest candidate index; run_single_iteration sizes its
+   // feature-value arrays from it (and hands it to setFeats).
+   int[] maxIndex = new int[numSentences];
+   int sent = 0;
+   while (sent < numSentences) {
+     maxIndex[sent++] = sizeOfNBest - 1;
+   }
+
+   // number of consecutive iterations in which an early-stopping criterion held
+   int earlyStop = 0;
+
+   for (int iteration = 1; ; ++iteration) {
+     double[] outcome =
+         run_single_iteration(iteration, minIts, maxIts, prevIts, earlyStop, maxIndex);
+     if (outcome == null) {
+       break; // no new candidates: the score from the previous iteration is kept
+     }
+     bestScore = outcome[0];
+     earlyStop = (int) outcome[1];
+     if (outcome[2] == 1) {
+       break; // run_single_iteration flagged this as the final iteration
+     }
+   }
+
+   println("", 1);
+   println(rule, 1);
+   println("Z-MERT run ended @ " + (new Date()), 1);
+   println(rule, 1);
+   println("", 1);
+   println("FINAL lambda: " + lambdaToString(lambda) + " (" + metricName_display + ": "
+       + bestScore + ")", 1);
+
+   // warn about any final weight lying outside its critical-value range
+   for (int c = 1; c <= numParams; ++c) {
+     boolean outOfRange = lambda[c] < minThValue[c] || lambda[c] > maxThValue[c];
+     if (outOfRange) {
+       println("Warning: after normalization, lambda[" + c + "]=" + f4.format(lambda[c])
+           + " is outside its critical value range.", 1);
+     }
+   }
+   println("", 1);
+
+   // remove the intermediate temp.{sents,feats,stats}.it<N>[.gz] files;
+   // a ".copy" stats file, when present, is deleted instead of the plain one
+   for (int it = 1; it <= maxIts; ++it) {
+     final String ext = (compressFiles == 1) ? ".gz" : "";
+     final String base = tmpDirPrefix + "temp.";
+     deleteFile(base + "sents.it" + it + ext);
+     deleteFile(base + "feats.it" + it + ext);
+     final String statsFile = base + "stats.it" + it;
+     final String statsCopy = statsFile + ".copy" + ext;
+     deleteFile(fileExists(statsCopy) ? statsCopy : statsFile + ext);
+   }
+ } // run_MERT(int, int, int)
+
+
+ @SuppressWarnings("unchecked")
+ public double[] run_single_iteration(int iteration, int minIts, int maxIts, int prevIts,
+ int earlyStop, int[] maxIndex) {
+ double FINAL_score = 0;
+
+ double[] retA = new double[3];
+ // retA[0]: FINAL_score
+ // retA[1]: earlyStop
+ // retA[2]: should this be the last iteration?
+
+ boolean done = false;
+ retA[2] = 1; // will only be made 0 if we don't break from the following loop
+
+
+ double[][][] featVal_array = new double[1 + numParams][][];
+ // indexed by [param][sentence][candidate]
+ featVal_array[0] = null; // param indexing starts at 1
+ for (int c = 1; c <= numParams; ++c) {
+ featVal_array[c] = new double[numSentences][];
+ for (int i = 0; i < numSentences; ++i) {
+ featVal_array[c][i] = new double[maxIndex[i] + 1];
+ // will grow dynamically as needed
+ }
+ }
+
+
+ while (!done) { // NOTE: this "loop" will only be carried out once
+ println("--- Starting Z-MERT iteration #" + iteration + " @ " + (new Date()) + " ---", 1);
+
+ // printMemoryUsage();
+
+ // run the decoder on all the sentences, producing for each sentence a set of
+ // sizeOfNBest candidates, with numParams feature values for each candidate
+
+ /******************************/
+ // CREATE DECODER CONFIG FILE //
+ /******************************/
+
+ createConfigFile(lambda, decoderConfigFileName, decoderConfigFileName + ".ZMERT.orig");
+ // i.e. use the original config file as a template
+
+ /***************/
+ // RUN DECODER //
+ /***************/
+
+ if (iteration == 1) {
+ println("Decoding using initial weight vector " + lambdaToString(lambda), 1);
+ } else {
+ println("Redecoding using weight vector " + lambdaToString(lambda), 1);
+ }
+
+ String[] decRunResult = run_decoder(iteration); // iteration passed in case fake decoder will
+ // be used
+ // [0] name of file to be processed
+ // [1] indicates how the output file was obtained:
+ // 1: external decoder
+ // 2: fake decoder
+ // 3: internal decoder
+
+ if (!decRunResult[1].equals("2")) {
+ println("...finished decoding @ " + (new Date()), 1);
+ }
+
+ checkFile(decRunResult[0]);
+
+ println("Producing temp files for iteration " + iteration, 3);
+
+ produceTempFiles(decRunResult[0], iteration);
+
+ if (saveInterFiles == 1 || saveInterFiles == 3) { // make copy of intermediate config file
+ if (!copyFile(decoderConfigFileName, decoderConfigFileName + ".ZMERT.it" + iteration)) {
+ println("Warning: attempt to make copy of decoder config file (to create"
+ + decoderConfigFileName + ".ZMERT.it" + iteration + ") was unsuccessful!", 1);
+ }
+ }
+ if (saveInterFiles == 2 || saveInterFiles == 3) { // make copy of intermediate decoder output
+ // file...
+
+ if (!decRunResult[1].equals("2")) { // ...but only if no fake decoder
+ if (!decRunResult[0].endsWith(".gz")) {
+ if (!copyFile(decRunResult[0], decRunResult[0] + ".ZMERT.it" + iteration)) {
+ println("Warning: attempt to make copy of decoder output file (to create"
+ + decRunResult[0] + ".ZMERT.it" + iteration + ") was unsuccessful!", 1);
+ }
+ } else {
+ String prefix = decRunResult[0].substring(0, decRunResult[0].length() - 3);
+ if (!copyFile(prefix + ".gz", prefix + ".ZMERT.it" + iteration + ".gz")) {
+ println("Warning: attempt to make copy of decoder output file (to create" + prefix
+ + ".ZMERT.it" + iteration + ".gz" + ") was unsuccessful!", 1);
+ }
+ }
+
+ if (compressFiles == 1 && !decRunResult[0].endsWith(".gz")) {
+ gzipFile(decRunResult[0] + ".ZMERT.it" + iteration);
+ }
+ } // if (!fake)
+
+ }
+
+ int[] candCount = new int[numSentences];
+ int[] lastUsedIndex = new int[numSentences];
+ ConcurrentHashMap<Integer, int[]>[] suffStats_array = new ConcurrentHashMap[numSentences];
+ for (int i = 0; i < numSentences; ++i) {
+ candCount[i] = 0;
+ lastUsedIndex[i] = -1;
+ // suffStats_array[i].clear();
+ suffStats_array[i] = new ConcurrentHashMap<Integer, int[]>();
+ }
+
+ double[][] initialLambda = new double[1 + initsPerIt][1 + numParams];
+ // the intermediate "initial" lambdas
+ double[][] finalLambda = new double[1 + initsPerIt][1 + numParams];
+ // the intermediate "final" lambdas
+
+ // set initialLambda[][]
+ System.arraycopy(lambda, 1, initialLambda[1], 1, numParams);
+ for (int j = 2; j <= initsPerIt; ++j) {
+ if (damianos_method == 0) {
+ initialLambda[j] = randomLambda();
+ } else {
+ initialLambda[j] =
+ randomPerturbation(initialLambda[1], iteration, damianos_method, damianos_param,
+ damianos_mult);
+ }
+ }
+
+// double[] initialScore = new double[1 + initsPerIt];
+ double[] finalScore = new double[1 + initsPerIt];
+
+ int[][][] best1Cand_suffStats = new int[1 + initsPerIt][numSentences][suffStatsCount];
+ double[][] best1Score = new double[1 + initsPerIt][numSentences];
+ // Those two arrays are used to calculate initialScore[]
+ // (the "score" in best1Score refers to that assigned by the
+ // decoder; the "score" in initialScore refers to that
+ // assigned by the evaluation metric)
+
+ int firstIt = Math.max(1, iteration - prevIts);
+ // i.e. only process candidates from the current iteration and candidates
+ // from up to prevIts previous iterations.
+ println("Reading candidate translations from iterations " + firstIt + "-" + iteration, 1);
+ println("(and computing " + metricName
+ + " sufficient statistics for previously unseen candidates)", 1);
+ print(" Progress: ");
+
+ int[] newCandidatesAdded = new int[1 + iteration];
+ for (int it = 1; it <= iteration; ++it) {
+ newCandidatesAdded[it] = 0;
+ }
+
+
+
+ try {
+
+ // each inFile corresponds to the output of an iteration
+ // (index 0 is not used; no corresponding index for the current iteration)
+ BufferedReader[] inFile_sents = new BufferedReader[iteration];
+ BufferedReader[] inFile_feats = new BufferedReader[iteration];
+ BufferedReader[] inFile_stats = new BufferedReader[iteration];
+
+ for (int it = firstIt; it < iteration; ++it) {
+ InputStream inStream_sents, inStream_feats, inStream_stats;
+ if (compressFiles == 0) {
+ inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+ inStream_feats = new FileInputStream(tmpDirPrefix + "temp.feats.it" + it);
+ inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+ } else {
+ inStream_sents =
+ new GZIPInputStream(
+ new FileInputStream(tmpDirPrefix + "temp.sents.it" + it + ".gz"));
+ inStream_feats =
+ new GZIPInputStream(
+ new FileInputStream(tmpDirPrefix + "temp.feats.it" + it + ".gz"));
+ inStream_stats =
+ new GZIPInputStream(
+ new FileInputStream(tmpDirPrefix + "temp.stats.it" + it + ".gz"));
+ }
+
+ inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+ inFile_feats[it] = new BufferedReader(new InputStreamReader(inStream_feats, "utf8"));
+ inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+ }
+
+
+ InputStream inStream_sentsCurrIt, inStream_featsCurrIt, inStream_statsCurrIt;
+ if (compressFiles == 0) {
+ inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+ inStream_featsCurrIt = new FileInputStream(tmpDirPrefix + "temp.feats.it" + iteration);
+ } else {
+ inStream_sentsCurrIt =
+ new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration
+ + ".gz"));
+ inStream_featsCurrIt =
+ new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.feats.it" + iteration
+ + ".gz"));
+ }
+
+ BufferedReader inFile_sentsCurrIt =
+ new BufferedReader(new InputStreamReader(inStream_sentsCurrIt, "utf8"));
+ BufferedReader inFile_featsCurrIt =
+ new BufferedReader(new InputStreamReader(inStream_featsCurrIt, "utf8"));
+
+ BufferedReader inFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below
+ // is set to true
+ PrintWriter outFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below is
+ // set to false
+ boolean statsCurrIt_exists = false;
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration)) {
+ inStream_statsCurrIt = new FileInputStream(tmpDirPrefix + "temp.stats.it" + iteration);
+ inFile_statsCurrIt =
+ new BufferedReader(new InputStreamReader(inStream_statsCurrIt, "utf8"));
+ statsCurrIt_exists = true;
+ copyFile(tmpDirPrefix + "temp.stats.it" + iteration, tmpDirPrefix + "temp.stats.it"
+ + iteration + ".copy");
+ } else if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".gz")) {
+ inStream_statsCurrIt =
+ new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it" + iteration
+ + ".gz"));
+ inFile_statsCurrIt =
+ new BufferedReader(new InputStreamReader(inStream_statsCurrIt, "utf8"));
+ statsCurrIt_exists = true;
+ copyFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz", tmpDirPrefix
+ + "temp.stats.it" + iteration + ".copy.gz");
+ } else {
+ outFile_statsCurrIt = new PrintWriter(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+
+ PrintWriter outFile_statsMerged = new PrintWriter(tmpDirPrefix + "temp.stats.merged");
+ // write sufficient statistics from all the sentences
+ // from the output files into a single file
+ PrintWriter outFile_statsMergedKnown =
+ new PrintWriter(tmpDirPrefix + "temp.stats.mergedKnown");
+ // write sufficient statistics from all the sentences
+ // from the output files into a single file
+
+ FileOutputStream outStream_unknownCands =
+ new FileOutputStream(tmpDirPrefix + "temp.currIt.unknownCands", false);
+ OutputStreamWriter outStreamWriter_unknownCands =
+ new OutputStreamWriter(outStream_unknownCands, "utf8");
+ BufferedWriter outFile_unknownCands = new BufferedWriter(outStreamWriter_unknownCands);
+
+ PrintWriter outFile_unknownIndices =
+ new PrintWriter(tmpDirPrefix + "temp.currIt.unknownIndices");
+
+
+ String sents_str, feats_str, stats_str;
+
+ // BUG: this assumes a candidate string cannot be produced for two
+ // different source sentences, which is not necessarily true
+ // (It's not actually a bug, but only because existingCandStats gets
+ // cleared before moving to the next source sentence.)
+ // FIX: should be made an array, indexed by i
+ HashMap<String, String> existingCandStats = new HashMap<String, String>();
+ // Stores precalculated sufficient statistics for candidates, in case
+ // the same candidate is seen again. (SS stored as a String.)
+ // Q: Why do we care? If we see the same candidate again, aren't we going
+ // to ignore it? So, why do we care about the SS of this repeat candidate?
+ // A: A "repeat" candidate may not be a repeat candidate in later
+ // iterations if the user specifies a value for prevMERTIterations
+ // that causes MERT to skip candidates from early iterations.
+ double[] currFeatVal = new double[1 + numParams];
+ String[] featVal_str;
+
+ int totalCandidateCount = 0;
+
+
+
+ int[] sizeUnknown_currIt = new int[numSentences];
+
+
+
+ for (int i = 0; i < numSentences; ++i) {
+
+ for (int j = 1; j <= initsPerIt; ++j) {
+ best1Score[j][i] = NegInf;
+ }
+
+ for (int it = firstIt; it < iteration; ++it) {
+ // Why up to but *excluding* iteration?
+ // Because the last iteration is handled a little differently, since
+ // the SS must be claculated (and the corresponding file created),
+ // which is not true for previous iterations.
+
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ // Why up to and *including* sizeOfNBest?
+ // So that it would read the "||||||" separator even if there is
+ // a complete list of sizeOfNBest candidates.
+
+ // for the nth candidate for the ith sentence, read the sentence, feature values,
+ // and sufficient statistics from the various temp files
+
+ sents_str = inFile_sents[it].readLine();
+ feats_str = inFile_feats[it].readLine();
+ stats_str = inFile_stats[it].readLine();
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+
+ outFile_statsMergedKnown.println(stats_str);
+
+ featVal_str = feats_str.split("\\s+");
+
+ /* Sparse (labeled) feature version */
+ if (feats_str.indexOf('=') != -1) {
+ for (String featurePair: featVal_str) {
+ String[] pair = featurePair.split("=");
+ String name = pair[0];
+ Double value = Double.parseDouble(pair[1]);
+ currFeatVal[c_fromParamName(name)] = value;
+ }
+ } else {
+ for (int c = 1; c <= numParams; ++c) {
+ try {
+ currFeatVal[c] = Double.parseDouble(featVal_str[c - 1]);
+ } catch (Exception e) {
+ currFeatVal[c] = 0.0;
+ }
+ // print("fV[" + c + "]=" + currFeatVal[c] + " ",4);
+ }
+ // println("",4);
+ }
+
+
+ for (int j = 1; j <= initsPerIt; ++j) {
+ double score = 0; // i.e. score assigned by decoder
+ for (int c = 1; c <= numParams; ++c) {
+ score += initialLambda[j][c] * currFeatVal[c];
+ }
+ if (score > best1Score[j][i]) {
+ best1Score[j][i] = score;
+ String[] tempStats = stats_str.split("\\s+");
+ for (int s = 0; s < suffStatsCount; ++s)
+ best1Cand_suffStats[j][i][s] = Integer.parseInt(tempStats[s]);
+ }
+ } // for (j)
+
+ existingCandStats.put(sents_str, stats_str);
+
+ setFeats(featVal_array, i, lastUsedIndex, maxIndex, currFeatVal);
+ candCount[i] += 1;
+
+ newCandidatesAdded[it] += 1;
+
+ } // if unseen candidate
+
+ } // for (n)
+
+ } // for (it)
+
+ outFile_statsMergedKnown.println("||||||");
+
+
+ // now process the candidates of the current iteration
+ // now determine the new candidates of the current iteration
+
+ /*
+ * remember: BufferedReader inFile_sentsCurrIt BufferedReader inFile_featsCurrIt
+ * PrintWriter outFile_statsCurrIt
+ */
+
+ String[] sentsCurrIt_currSrcSent = new String[sizeOfNBest + 1];
+
+ Vector<String> unknownCands_V = new Vector<String>();
+ // which candidates (of the i'th source sentence) have not been seen before
+ // this iteration?
+
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ // Why up to and *including* sizeOfNBest?
+ // So that it would read the "||||||" separator even if there is
+ // a complete list of sizeOfNBest candidates.
+
+ // for the nth candidate for the ith sentence, read the sentence,
+ // and store it in the sentsCurrIt_currSrcSent array
+
+ sents_str = inFile_sentsCurrIt.readLine();
+ sentsCurrIt_currSrcSent[n] = sents_str; // Note: possibly "||||||"
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+ unknownCands_V.add(sents_str);
+ writeLine(sents_str, outFile_unknownCands);
+ outFile_unknownIndices.println(i);
+ newCandidatesAdded[iteration] += 1;
+ existingCandStats.put(sents_str, "U"); // i.e. unknown
+ // we add sents_str to avoid duplicate entries in unknownCands_V
+ }
+
+ } // for (n)
+
+
+
+ // now unknownCands_V has the candidates for which we need to calculate
+ // sufficient statistics (for the i'th source sentence)
+ int sizeUnknown = unknownCands_V.size();
+ sizeUnknown_currIt[i] = sizeUnknown;
+
+ /*********************************************/
+ /*
+ * String[] unknownCands = new String[sizeUnknown]; unknownCands_V.toArray(unknownCands);
+ * int[] indices = new int[sizeUnknown]; for (int d = 0; d < sizeUnknown; ++d) {
+ * existingCandStats.remove(unknownCands[d]); // remove the (unknownCands[d],"U") entry
+ * from existingCandStats // (we had added it while constructing unknownCands_V to avoid
+ * duplicate entries) indices[d] = i; }
+ */
+ /*********************************************/
+
+ existingCandStats.clear();
+
+ } // for (i)
+
+ /*
+ * int[][] newSuffStats = null; if (!statsCurrIt_exists && sizeUnknown > 0) { newSuffStats =
+ * evalMetric.suffStats(unknownCands, indices); }
+ */
+
+ outFile_statsMergedKnown.close();
+ outFile_unknownCands.close();
+ outFile_unknownIndices.close();
+
+
+ for (int it = firstIt; it < iteration; ++it) {
+ inFile_sents[it].close();
+ inFile_stats[it].close();
+
+ InputStream inStream_sents, inStream_stats;
+ if (compressFiles == 0) {
+ inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+ inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+ } else {
+ inStream_sents =
+ new GZIPInputStream(
+ new FileInputStream(tmpDirPrefix + "temp.sents.it" + it + ".gz"));
+ inStream_stats =
+ new GZIPInputStream(
+ new FileInputStream(tmpDirPrefix + "temp.stats.it" + it + ".gz"));
+ }
+
+ inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+ inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+ }
+
+ inFile_sentsCurrIt.close();
+ if (compressFiles == 0) {
+ inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+ } else {
+ inStream_sentsCurrIt =
+ new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration
+ + ".gz"));
+ }
+ inFile_sentsCurrIt =
+ new BufferedReader(new InputStreamReader(inStream_sentsCurrIt, "utf8"));
+
+
+
+ // calculate SS for unseen candidates and write them to file
+ FileInputStream inStream_statsCurrIt_unknown = null;
+ BufferedReader inFile_statsCurrIt_unknown = null;
+
+ if (!statsCurrIt_exists && newCandidatesAdded[iteration] > 0) {
+ // create the file...
+ evalMetric.createSuffStatsFile(tmpDirPrefix + "temp.currIt.unknownCands", tmpDirPrefix
+ + "temp.currIt.unknownIndices", tmpDirPrefix + "temp.stats.unknown", sizeOfNBest);
+
+ // ...and open it
+ inStream_statsCurrIt_unknown = new FileInputStream(tmpDirPrefix + "temp.stats.unknown");
+ inFile_statsCurrIt_unknown =
+ new BufferedReader(new InputStreamReader(inStream_statsCurrIt_unknown, "utf8"));
+ }
+
+ // OPEN mergedKnown file
+ FileInputStream instream_statsMergedKnown =
+ new FileInputStream(tmpDirPrefix + "temp.stats.mergedKnown");
+ BufferedReader inFile_statsMergedKnown =
+ new BufferedReader(new InputStreamReader(instream_statsMergedKnown, "utf8"));
+
+ for (int i = 0; i < numSentences; ++i) {
+
+ // reprocess candidates from previous iterations
+ for (int it = firstIt; it < iteration; ++it) {
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+
+ sents_str = inFile_sents[it].readLine();
+ stats_str = inFile_stats[it].readLine();
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+ existingCandStats.put(sents_str, stats_str);
+ } // if unseen candidate
+
+ } // for (n)
+ } // for (it)
+
+ // copy relevant portion from mergedKnown to the merged file
+ String line_mergedKnown = inFile_statsMergedKnown.readLine();
+ while (!line_mergedKnown.equals("||||||")) {
+ outFile_statsMerged.println(line_mergedKnown);
+ line_mergedKnown = inFile_statsMergedKnown.readLine();
+ }
+
+ int[] stats = new int[suffStatsCount];
+
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ // Why up to and *including* sizeOfNBest?
+ // So that it would read the "||||||" separator even if there is
+ // a complete list of sizeOfNBest candidates.
+
+ // for the nth candidate for the ith sentence, read the sentence, feature values,
+ // and sufficient statistics from the various temp files
+
+ sents_str = inFile_sentsCurrIt.readLine();
+ feats_str = inFile_featsCurrIt.readLine();
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+
+ if (!statsCurrIt_exists) {
+ stats_str = inFile_statsCurrIt_unknown.readLine();
+
+ String[] temp_stats = stats_str.split("\\s+");
+ for (int s = 0; s < suffStatsCount; ++s) {
+ stats[s] = Integer.parseInt(temp_stats[s]);
+ }
+
+ /*
+ * stats_str = ""; for (int s = 0; s < suffStatsCount-1; ++s) { stats[s] =
+ * newSuffStats[d][s]; stats_str += (stats[s] + " "); } stats[suffStatsCount-1] =
+ * newSuffStats[d][suffStatsCount-1]; stats_str += stats[suffStatsCount-1];
+ */
+
+ outFile_statsCurrIt.println(stats_str);
+ } else {
+ stats_str = inFile_statsCurrIt.readLine();
+ String[] temp_stats = stats_str.split("\\s+");
+ for (int s = 0; s < suffStatsCount; ++s) {
+ try {
+ stats[s] = Integer.parseInt(temp_stats[s]);
+ } catch (Exception e) {
+ stats[s] = 0;
+ }
+ }
+ }
+
+ outFile_statsMerged.println(stats_str);
+
+ featVal_str = feats_str.split("\\s+");
+
+ if (feats_str.indexOf('=') != -1) {
+ for (String featurePair: featVal_str) {
+ String[] pair = featurePair.split("=");
+ String name = pair[0];
+ Double value = Double.parseDouble(pair[1]);
+ currFeatVal[c_fromParamName(name)] = value;
+ }
+ } else {
+ for (int c = 1; c <= numParams; ++c) {
+ try {
+ currFeatVal[c] = Double.parseDouble(featVal_str[c - 1]);
+ } catch (Exception e) {
+ // NumberFormatException, ArrayIndexOutOfBoundsException
+ currFeatVal[c] = 0.0;
+ }
+
+ // print("fV[" + c + "]=" + currFeatVal[c] + " ",4);
+ }
+ }
+ // println("",4);
+
+
+ for (int j = 1; j <= initsPerIt; ++j) {
+ double score = 0; // i.e. score assigned by decoder
+ for (int c = 1; c <= numParams; ++c) {
+ score += initialLambda[j][c] * currFeatVal[c];
+ }
+ if (score > best1Score[j][i]) {
+ best1Score[j][i] = score;
+ for (int s = 0; s < suffStatsCount; ++s)
+ best1Cand_suffStats[j][i][s] = stats[s];
+ }
+ } // for (j)
+
+ existingCandStats.put(sents_str, stats_str);
+
+ setFeats(featVal_array, i, lastUsedIndex, maxIndex, currFeatVal);
+ candCount[i] += 1;
+
+ // newCandidatesAdded[iteration] += 1;
+ // moved to code above detecting new candidates
+
+ } else {
+ if (statsCurrIt_exists)
+ inFile_statsCurrIt.readLine();
+ else {
+ // write SS to outFile_statsCurrIt
+ stats_str = existingCandStats.get(sents_str);
+ outFile_statsCurrIt.println(stats_str);
+ }
+ }
+
+ } // for (n)
+
+ // now d = sizeUnknown_currIt[i] - 1
+
+ if (statsCurrIt_exists)
+ inFile_statsCurrIt.readLine();
+ else
+ outFile_statsCurrIt.println("||||||");
+
+ existingCandStats.clear();
+ totalCandidateCount += candCount[i];
+
+ if ((i + 1) % 500 == 0) {
+ print((i + 1) + "\n" + " ", 1);
+ } else if ((i + 1) % 100 == 0) {
+ print("+", 1);
+ } else if ((i + 1) % 25 == 0) {
+ print(".", 1);
+ }
+
+ } // for (i)
+
+ inFile_statsMergedKnown.close();
+ outFile_statsMerged.close();
+
+ println("", 1); // finish progress line
+
+ for (int it = firstIt; it < iteration; ++it) {
+ inFile_sents[it].close();
+ inFile_feats[it].close();
+ inFile_stats[it].close();
+ }
+
+ inFile_sentsCurrIt.close();
+ inFile_featsCurrIt.close();
+ if (statsCurrIt_exists)
+ inFile_statsCurrIt.close();
+ else
+ outFile_statsCurrIt.close();
+
+ if (compressFiles == 1 && !statsCurrIt_exists) {
+ gzipFile(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+
+ deleteFile(tmpDirPrefix + "temp.currIt.unknownCands");
+ deleteFile(tmpDirPrefix + "temp.currIt.unknownIndices");
+ deleteFile(tmpDirPrefix + "temp.stats.unknown");
+ deleteFile(tmpDirPrefix + "temp.stats.mergedKnown");
+
+ // cleanupMemory();
+
+ println("Processed " + totalCandidateCount + " distinct candidates " + "(about "
+ + totalCandidateCount / numSentences + " per sentence):", 1);
+ for (int it = firstIt; it <= iteration; ++it) {
+ println("newCandidatesAdded[it=" + it + "] = " + newCandidatesAdded[it] + " (about "
+ + newCandidatesAdded[it] / numSentences + " per sentence)", 1);
+ }
+
+ println("", 1);
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+
+ if (newCandidatesAdded[iteration] == 0) {
+ if (!oneModificationPerIteration) {
+ println("No new candidates added in this iteration; exiting Z-MERT.", 1);
+ println("", 1);
+ println("--- Z-MERT iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+ println("", 1);
+ return null; // THIS MEANS THAT THE OLD VALUES SHOULD BE KEPT BY THE CALLER
+ } else {
+ println("Note: No new candidates added in this iteration.", 1);
+ }
+ }
+
+ // run the initsPerIt optimizations, in parallel, across numOptThreads threads
+ ExecutorService pool = Executors.newFixedThreadPool(numOptThreads);
+ Semaphore blocker = new Semaphore(0);
+ Vector<String>[] threadOutput = new Vector[initsPerIt + 1];
+
+ for (int j = 1; j <= initsPerIt; ++j) {
+ threadOutput[j] = new Vector<String>();
+ pool.execute(new IntermediateOptimizer(j, blocker, threadOutput[j], initialLambda[j],
+ finalLambda[j], best1Cand_suffStats[j], finalScore, candCount, featVal_array,
+ suffStats_array));
+ }
+
+ pool.shutdown();
+
+ try {
+ blocker.acquire(initsPerIt);
+ } catch (java.lang.InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+
+ // extract output from threadOutput[]
+ for (int j = 1; j <= initsPerIt; ++j) {
+ for (String str : threadOutput[j]) {
+ println(str); // no verbosity check needed; thread already checked
+ }
+ }
+
+ int best_j = 1;
+ double bestFinalScore = finalScore[1];
+ for (int j = 2; j <= initsPerIt; ++j) {
+ if (evalMetric.isBetter(finalScore[j], bestFinalScore)) {
+ best_j = j;
+ bestFinalScore = finalScore[j];
+ }
+ }
+
+ if (initsPerIt > 1) {
+ println("Best final lambda is lambda[j=" + best_j + "] " + "(" + metricName_display + ": "
+ + f4.format(bestFinalScore) + ").", 1);
+ println("", 1);
+ }
+
+ FINAL_score = bestFinalScore;
+
+ boolean anyParamChanged = false;
+ boolean anyParamChangedSignificantly = false;
+
+ for (int c = 1; c <= numParams; ++c) {
+ if (finalLambda[best_j][c] != lambda[c]) {
+ anyParamChanged = true;
+ }
+ if (Math.abs(finalLambda[best_j][c] - lambda[c]) > stopSigValue) {
+ anyParamChangedSignificantly = true;
+ }
+ }
+
+ System.arraycopy(finalLambda[best_j], 1, lambda, 1, numParams);
+ println("--- Z-MERT iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+ println("", 1);
+
+ if (!anyParamChanged) {
+ println("No parameter value changed in this iteration; exiting Z-MERT.", 1);
+ println("", 1);
+ break; // exit for (iteration) loop preemptively
+ }
+
+ // check if a lambda is outside its threshold range
+ for (int c = 1; c <= numParams; ++c) {
+ if (lambda[c] < minThValue[c] || lambda[c] > maxThValue[c]) {
+ println("Warning: after normalization, lambda[" + c + "]=" + f4.format(lambda[c])
+ + " is outside its critical value range.", 1);
+ }
+ }
+
+ // was an early stopping criterion satisfied?
+ boolean critSatisfied = false;
+ if (!anyParamChangedSignificantly && stopSigValue >= 0) {
+ println("Note: No parameter value changed significantly " + "(i.e. by more than "
+ + stopSigValue + ") in this iteration.", 1);
+ critSatisfied = true;
+ }
+
+ if (critSatisfied) {
+ ++earlyStop;
+ println("", 1);
+ } else {
+ earlyStop = 0;
+ }
+
+ // if min number of iterations executed, investigate if early exit should happen
+ if (iteration >= minIts && earlyStop >= stopMinIts) {
+ println("Some early stopping criteria has been observed " + "in " + stopMinIts
+ + " consecutive iterations; exiting Z-MERT.", 1);
+ println("", 1);
+ break; // exit for (iteration) loop preemptively
+ }
+
+ // if max number of iterations executed, exit
+ if (iteration >= maxIts) {
+ println("Maximum number of MERT iterations reached; exiting Z-MERT.", 1);
+ println("", 1);
+ break; // exit for (iteration) loop
+ }
+
+ println("Next iteration will decode with lambda: " + lambdaToString(lambda), 1);
+ println("", 1);
+
+ // printMemoryUsage();
+ for (int i = 0; i < numSentences; ++i) {
+ suffStats_array[i].clear();
+ }
+ // cleanupMemory();
+ // println("",2);
+
+
+ retA[2] = 0; // i.e. this should NOT be the last iteration
+ done = true;
+
+ } // while (!done) // NOTE: this "loop" will only be carried out once
+
+
+ // delete .temp.stats.merged file, since it is not needed in the next
+ // iteration (it will be recreated from scratch)
+ deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+ retA[0] = FINAL_score;
+ retA[1] = earlyStop;
+ return retA;
+
+ } // run_single_iteration
+
+  /**
+   * Renders lambdaA[1..numParams] as a brace-delimited, comma-separated list such as
+   * "{0.5, -1.0, 2.0}". Element 0 is not printed (the parameter arrays are 1-indexed).
+   */
+  private String lambdaToString(double[] lambdaA) {
+    StringBuilder sb = new StringBuilder("{");
+    for (int c = 1; c < numParams; ++c) {
+      sb.append(lambdaA[c]).append(", ");
+    }
+    sb.append(lambdaA[numParams]).append('}');
+    return sb.toString();
+  }
+
+  /**
+   * Obtains the n-best output for the given MERT iteration.
+   *
+   * <p>If a fake-output template is configured and the fake file for this iteration exists,
+   * that file is used and no decoder is run. Otherwise the external decoder command is run
+   * (with the iteration number appended as an argument when passIterationToDecoder is set), and
+   * its merged stdout/stderr is consumed by a {@link StreamGobbler}.
+   *
+   * @param iteration the current MERT iteration number
+   * @return a two-element array: [0] the name of the file to be processed; [1] how it was
+   *     obtained: "1" external decoder, "2" fake decoder ("3", internal decoder, is a listed
+   *     code but never produced by this method)
+   * @throws RuntimeException if the decoder cannot be started, the wait is interrupted, or the
+   *     decoder exits with a status other than validDecoderExitValue
+   */
+  private String[] run_decoder(int iteration) {
+    String[] retSA = new String[2];
+
+    if (fakeFileNameTemplate != null
+        && fileExists(fakeFileNamePrefix + iteration + fakeFileNameSuffix)) {
+      String fakeFileName = fakeFileNamePrefix + iteration + fakeFileNameSuffix;
+      println("Not running decoder; using " + fakeFileName + " instead.", 1);
+      retSA[0] = fakeFileName;
+      retSA[1] = "2";
+      return retSA;
+    }
+
+    println("Running external decoder...", 1);
+
+    try {
+      ArrayList<String> cmd = new ArrayList<String>();
+      cmd.add(decoderCommandFileName);
+      if (passIterationToDecoder) {
+        cmd.add(Integer.toString(iteration));
+      }
+
+      ProcessBuilder pb = new ProcessBuilder(cmd);
+      // merge the subprocess's stderr into its stdout so a single gobbler drains both
+      pb.redirectErrorStream(true);
+      Process p = pb.start();
+
+      // capture the sub-command's output
+      StreamGobbler outputGobbler = new StreamGobbler(p.getInputStream(), decVerbosity);
+      outputGobbler.start();
+
+      int decStatus = p.waitFor();
+      if (decStatus != validDecoderExitValue) {
+        throw new RuntimeException("Call to decoder returned " + decStatus + "; was expecting "
+            + validDecoderExitValue + ".");
+      }
+    } catch (InterruptedException e) {
+      // restore the interrupt flag so code further up the stack can still observe it
+      Thread.currentThread().interrupt();
+      throw new RuntimeException(e);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    retSA[0] = decoderOutFileName;
+    retSA[1] = "1";
+    return retSA;
+  }
+
+  /**
+   * Splits an n-best file into two parallel temp files for this iteration:
+   * tmpDirPrefix + "temp.sents.it" + iteration (normalized candidate strings, written as UTF-8)
+   * and tmpDirPrefix + "temp.feats.it" + iteration (the feature-value strings). Each sentence's
+   * run of candidates is terminated by a "||||||" line in both files. Both files are gzipped
+   * afterwards when compressFiles == 1.
+   *
+   * <p>Expected n-best line format:
+   *
+   * <pre>
+   * i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val
+   * </pre>
+   *
+   * Updated September 2012: features can now be named (for sparse feature compatibility). You
+   * must name all features or none of them. Anything after a third "|||" is discarded. Blank
+   * lines and lines without "|||" are skipped.
+   *
+   * @param nbestFileName the n-best file to read; decompressed with gzip if it ends in ".gz"
+   * @param iteration the current MERT iteration, used to name the temp files
+   * @throws RuntimeException wrapping any I/O error
+   */
+  private void produceTempFiles(String nbestFileName, int iteration) {
+    String sentsFileName = tmpDirPrefix + "temp.sents.it" + iteration;
+    String featsFileName = tmpDirPrefix + "temp.feats.it" + iteration;
+
+    try {
+      // try-with-resources closes every stream even if parsing fails part-way through
+      try (BufferedWriter outFile_sents = new BufferedWriter(new OutputStreamWriter(
+              new FileOutputStream(sentsFileName, false), "utf8"));
+          PrintWriter outFile_feats = new PrintWriter(featsFileName);
+          InputStream rawIn_nbest = new FileInputStream(nbestFileName);
+          InputStream inStream_nbest = nbestFileName.endsWith(".gz")
+              ? new GZIPInputStream(rawIn_nbest) : rawIn_nbest;
+          BufferedReader inFile_nbest =
+              new BufferedReader(new InputStreamReader(inStream_nbest, "utf8"))) {
+
+        int i = 0; // index of the sentence whose candidates are currently being written
+        int n = 0; // number of candidates written so far for sentence i
+        String line;
+        // reading in the loop condition guarantees that `continue` advances to the next line
+        while ((line = inFile_nbest.readLine()) != null) {
+          // skip blank lines and lines that aren't formatted correctly
+          if (line.isEmpty() || line.indexOf("|||") == -1) {
+            continue;
+          }
+
+          // in a well formed file, we'd find the nth candidate for the ith sentence
+          int read_i = Integer.parseInt(line.substring(0, line.indexOf("|||")).trim());
+          if (read_i != i) {
+            // a new sentence index began before sizeOfNBest candidates were seen
+            writeLine("||||||", outFile_sents);
+            outFile_feats.println("||||||");
+            n = 0;
+            ++i;
+          }
+
+          line = line.substring(line.indexOf("|||") + 3).trim(); // get rid of initial text
+
+          String candidate_str = line.substring(0, line.indexOf("|||")).trim();
+          String feats_str = line.substring(line.indexOf("|||") + 3).trim();
+
+          // discard any extra fields that follow the feature values
+          int junk_i = feats_str.indexOf("|||");
+          if (junk_i >= 0) {
+            feats_str = feats_str.substring(0, junk_i).trim();
+          }
+
+          writeLine(normalize(candidate_str, textNormMethod), outFile_sents);
+          outFile_feats.println(feats_str);
+
+          ++n;
+          if (n == sizeOfNBest) {
+            writeLine("||||||", outFile_sents);
+            outFile_feats.println("||||||");
+            n = 0;
+            ++i;
+          }
+        }
+
+        if (i != numSentences) { // last sentence had too few candidates
+          writeLine("||||||", outFile_sents);
+          outFile_feats.println("||||||");
+        }
+      }
+
+      // compress only after the writers above have been flushed and closed
+      if (compressFiles == 1) {
+        gzipFile(sentsFileName);
+        gzipFile(featsFileName);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Creates cfgFileName as a copy of templateFileName in which every line starting with
+   * paramNames[c] + " " (for the first matching c in 1..numParams) is replaced by
+   * paramNames[c] + " " + params[c]. All other lines are copied unchanged.
+   *
+   * @param params parameter values, indexed like paramNames (1..numParams)
+   * @param cfgFileName the config file to create (overwritten if it exists)
+   * @param templateFileName the config file used as a template
+   * @throws RuntimeException wrapping any I/O error, including write errors reported by the
+   *     PrintWriter
+   */
+  private void createConfigFile(double[] params, String cfgFileName, String templateFileName) {
+    // try-with-resources so neither file handle leaks if reading or writing fails midway
+    try (BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
+        PrintWriter outFile = new PrintWriter(cfgFileName)) {
+      String line;
+      while ((line = inFile.readLine()) != null) {
+        int c_match = -1;
+        for (int c = 1; c <= numParams; ++c) {
+          if (line.startsWith(paramNames[c] + " ")) {
+            c_match = c;
+            break;
+          }
+        }
+
+        if (c_match == -1) {
+          outFile.println(line);
+        } else {
+          outFile.println(paramNames[c_match] + " " + params[c_match]);
+        }
+      }
+
+      // PrintWriter never throws on write; check explicitly so a truncated config isn't used
+      if (outFile.checkError()) {
+        throw new RuntimeException("Error writing config file " + cfgFileName);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Reads the parameter file paramsFileName, filling lambda[], defaultLambda[],
+   * isOptimizable[], minThValue[], maxThValue[], minRandValue[], maxRandValue[] (indices
+   * 1..numParams) and normalizationOptions[].
+   *
+   * <p>For each of the numParams parameters the file holds: the parameter name (one or more
+   * tokens), "|||", the default value, "Opt" or "Fix", and four values minThValue maxThValue
+   * minRandValue maxRandValue (read but ignored for Fix parameters). These entries are followed
+   * by one normalization line, e.g. "normalization = absval 1 lm".
+   *
+   * @throws RuntimeException if the file cannot be opened or contains illegal values
+   */
+  private void processParamFile() {
+    // try-with-resources closes the scanner on every exit path, including validation failures
+    try (Scanner inFile_init = new Scanner(new FileReader(paramsFileName))) {
+      readParamSpecs(inFile_init);
+      readNormalizationSpec(inFile_init);
+    } catch (FileNotFoundException e) {
+      throw new RuntimeException("FileNotFoundException in MertCore.processParamFile(): "
+          + e.getMessage(), e);
+    }
+  }
+
+  /** Reads the numParams per-parameter entries (name ||| default Opt|Fix four values). */
+  private void readParamSpecs(Scanner inFile_init) {
+    String dummy = "";
+
+    for (int c = 1; c <= numParams; ++c) {
+      // skip parameter name (it may span several tokens) up to the "|||" separator
+      while (!dummy.equals("|||")) {
+        dummy = inFile_init.next();
+      }
+
+      // read default value
+      lambda[c] = inFile_init.nextDouble();
+      defaultLambda[c] = lambda[c];
+
+      // read isOptimizable
+      dummy = inFile_init.next();
+      if (dummy.equals("Opt")) {
+        isOptimizable[c] = true;
+      } else if (dummy.equals("Fix")) {
+        isOptimizable[c] = false;
+      } else {
+        throw new RuntimeException("Unknown isOptimizable string " + dummy
+            + " (must be either Opt or Fix)");
+      }
+
+      if (!isOptimizable[c]) {
+        // skip the four range values, which are irrelevant for a fixed parameter
+        for (int k = 0; k < 4; ++k) {
+          dummy = inFile_init.next();
+        }
+        continue;
+      }
+
+      // set minThValue[c] and maxThValue[c] (range for thresholds to investigate)
+      dummy = inFile_init.next();
+      if (dummy.equals("-Inf")) {
+        minThValue[c] = NegInf;
+      } else if (dummy.equals("+Inf")) {
+        throw new RuntimeException("minThValue[" + c + "] cannot be +Inf!");
+      } else {
+        minThValue[c] = Double.parseDouble(dummy);
+      }
+
+      dummy = inFile_init.next();
+      if (dummy.equals("-Inf")) {
+        throw new RuntimeException("maxThValue[" + c + "] cannot be -Inf!");
+      } else if (dummy.equals("+Inf")) {
+        maxThValue[c] = PosInf;
+      } else {
+        maxThValue[c] = Double.parseDouble(dummy);
+      }
+
+      // set minRandValue[c] and maxRandValue[c] (range for random values)
+      dummy = inFile_init.next();
+      if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+        throw new RuntimeException("minRandValue[" + c + "] cannot be -Inf or +Inf!");
+      } else {
+        minRandValue[c] = Double.parseDouble(dummy);
+      }
+
+      dummy = inFile_init.next();
+      if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+        throw new RuntimeException("maxRandValue[" + c + "] cannot be -Inf or +Inf!");
+      } else {
+        maxRandValue[c] = Double.parseDouble(dummy);
+      }
+
+      // check for illogical values
+      if (minThValue[c] > maxThValue[c]) {
+        throw new RuntimeException("minThValue[" + c + "]=" + minThValue[c]
+            + " > " + maxThValue[c] + "=maxThValue[" + c + "]!");
+      }
+      if (minRandValue[c] > maxRandValue[c]) {
+        throw new RuntimeException("minRandValue[" + c + "]=" + minRandValue[c]
+            + " > " + maxRandValue[c] + "=maxRandValue[" + c + "]!");
+      }
+
+      // check for odd values (warnings only)
+      if (!(minThValue[c] <= lambda[c] && lambda[c] <= maxThValue[c])) {
+        println("Warning: lambda[" + c + "] has initial value (" + lambda[c] + ")", 1);
+        println(" that is outside its critical value range " + "[" + minThValue[c] + ","
+            + maxThValue[c] + "]", 1);
+      }
+
+      if (minThValue[c] == maxThValue[c]) {
+        println("Warning: lambda[" + c + "] has " + "minThValue = maxThValue = " + minThValue[c]
+            + ".", 1);
+      }
+
+      if (minRandValue[c] == maxRandValue[c]) {
+        println("Warning: lambda[" + c + "] has " + "minRandValue = maxRandValue = "
+            + minRandValue[c] + ".", 1);
+      }
+
+      if (minRandValue[c] < minThValue[c] || minRandValue[c] > maxThValue[c]
+          || maxRandValue[c] < minThValue[c] || maxRandValue[c] > maxThValue[c]) {
+        println("Warning: The random value range for lambda[" + c + "] is not contained", 1);
+        println(" within its critical value range.", 1);
+      }
+    }
+  }
+
+  /**
+   * Reads the normalization line that follows the parameter entries and sets
+   * normalizationOptions[].
+   *
+   * <p>How should a lambda[] vector be normalized (before decoding)?
+   * <ul>
+   * <li>nO[0] = 0: no normalization ("normalization = none")
+   * <li>nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+   * ("normalization = absval 1 lm")
+   * <li>nO[0] = 2: scale so that the maximum absolute value is nO[1]
+   * ("normalization = maxabsval 1")
+   * <li>nO[0] = 3: scale so that the minimum absolute value is nO[1]
+   * ("normalization = minabsval 1")
+   * <li>nO[0] = 4: scale so that the L-nO[1] norm equals nO[2] ("normalization = LNorm 2 1")
+   * </ul>
+   */
+  private void readNormalizationSpec(Scanner inFile_init) {
+    // the first nextLine() returns the remainder of the last parameter's line, so skip
+    // empty lines until the normalization line is reached
+    String origLine = "";
+    while (origLine.length() == 0 && inFile_init.hasNextLine()) {
+      origLine = inFile_init.nextLine();
+    }
+    if (origLine.length() == 0) {
+      throw new RuntimeException("No normalization line found in parameter file "
+          + paramsFileName + ".");
+    }
+
+    String dummy = (origLine.substring(origLine.indexOf("=") + 1)).trim();
+    String[] dummyA = dummy.split("\\s+");
+
+    if (dummyA[0].equals("none")) {
+      normalizationOptions[0] = 0;
+    } else if (dummyA[0].equals("absval")) {
+      normalizationOptions[0] = 1;
+      normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+      String pName = dummyA[2];
+      for (int i = 3; i < dummyA.length; ++i) { // in case parameter name has multiple words
+        pName = pName + " " + dummyA[i];
+      }
+      normalizationOptions[2] = c_fromParamName(pName);
+
+      if (normalizationOptions[1] <= 0) {
+        throw new RuntimeException("Value for the absval normalization method must be positive.");
+      }
+      if (normalizationOptions[2] == 0) {
+        throw new RuntimeException("Unrecognized feature name " + pName
+            + " for absval normalization method.");
+      }
+    } else if (dummyA[0].equals("maxabsval")) {
+      normalizationOptions[0] = 2;
+      normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+      if (normalizationOptions[1] <= 0) {
+        throw new RuntimeException("Value for the maxabsval normalization method must be positive.");
+      }
+    } else if (dummyA[0].equals("minabsval")) {
+      normalizationOptions[0] = 3;
+      normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+      if (normalizationOptions[1] <= 0) {
+        throw new RuntimeException("Value for the minabsval normalization method must be positive.");
+      }
+    } else if (dummyA[0].equals("LNorm")) {
+      normalizationOptions[0] = 4;
+      normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+      normalizationOptions[2] = Double.parseDouble(dummyA[2]);
+      if (normalizationOptions[1] <= 0 || normalizationOptions[2] <= 0) {
+        throw new RuntimeException("Both values for the LNorm normalization method must be positive.");
+      }
+    } else {
+      throw new RuntimeException("Unrecognized normalization method " + dummyA[0] + "; "
+          + "must be one of none, absval, maxabsval, minabsval, and LNorm.");
+    }
+  }
+
+  /**
+   * Sets numDocuments and docOfSentence[] (the 0-based document index of each sentence).
+   *
+   * <p>Without a doc-info file, every sentence belongs to a single document. Otherwise
+   * docInfoFileName uses one of 4 possible formats:
+   * <ol>
+   * <li>List of numbers, one per document, indicating # sentences in each document.
+   * <li>List of "docName size" pairs, one per document, indicating name of document and #
+   * sentences.
+   * <li>List of docName's, one per sentence, indicating which document each sentence belongs
+   * to.
+   * <li>List of docName_number's, one per sentence, indicating which document each sentence
+   * belongs to, and its order in that document. (can also use '-' instead of '_')
+   * </ol>
+   * Formats 1/2 are recognized by having fewer non-empty lines than there are sentences and
+   * formats 3/4 by having exactly one line per sentence; format 3 is chosen over 4 when any line
+   * is repeated.
+   *
+   * @throws RuntimeException on I/O errors, an empty file, a format-4 line without a separator,
+   *     or a file with more non-empty lines than there are sentences
+   */
+  private void processDocInfo() {
+    docOfSentence = new int[numSentences];
+
+    if (docInfoFileName == null) {
+      // a new int[] is already all zeros: every sentence is in document 0
+      numDocuments = 1;
+      return;
+    }
+
+    try {
+      int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
+
+      if (docInfoSize < numSentences) { // format #1 or #2
+        numDocuments = docInfoSize;
+
+        try (BufferedReader inFile = new BufferedReader(new FileReader(docInfoFileName))) {
+          String line = inFile.readLine();
+          if (line == null) {
+            throw new RuntimeException("Document info file " + docInfoFileName + " is empty.");
+          }
+          boolean format1 = !line.contains(" ");
+
+          int i = 0; // next sentence index to assign
+          for (int doc = 0; doc < numDocuments; ++doc) {
+            if (doc != 0) {
+              line = inFile.readLine();
+            }
+
+            int docSize = format1
+                ? Integer.parseInt(line)
+                : Integer.parseInt(line.split("\\s+")[1]);
+
+            for (int k = 0; k < docSize; ++k) {
+              docOfSentence[i] = doc;
+              ++i;
+            }
+          }
+          // for a consistent file, i == numSentences here
+        }
+
+      } else if (docInfoSize == numSentences) { // format #3 or #4
+        // read the file once; both the format check and the assignment use these lines
+        ArrayList<String> lines = new ArrayList<String>(numSentences);
+        try (BufferedReader inFile = new BufferedReader(new FileReader(docInfoFileName))) {
+          for (int i = 0; i < numSentences; ++i) {
+            lines.add(inFile.readLine());
+          }
+        }
+
+        // format #3 if a duplicate is found (several sentences name the same document)
+        boolean format3 = new HashSet<String>(lines).size() < lines.size();
+
+        // maps a document name to the order (0-indexed) in which it was first seen
+        HashMap<String, Integer> docOrder = new HashMap<String, Integer>();
+
+        for (int i = 0; i < numSentences; ++i) {
+          String line = lines.get(i);
+
+          String docName;
+          if (format3) {
+            docName = line;
+          } else {
+            int sep_i = Math.max(line.lastIndexOf('_'), line.lastIndexOf('-'));
+            if (sep_i < 0) {
+              throw new RuntimeException("Malformed line \"" + line + "\" in document info file "
+                  + docInfoFileName + " (expected docName_number or docName-number).");
+            }
+            docName = line.substring(0, sep_i);
+          }
+
+          Integer docOrder_i = docOrder.get(docName);
+          if (docOrder_i == null) {
+            docOrder_i = docOrder.size();
+            docOrder.put(docName, docOrder_i);
+          }
+          docOfSentence[i] = docOrder_i;
+        }
+
+        numDocuments = docOrder.size();
+
+      } else { // badly formatted
+        throw new RuntimeException("Document info file " + docInfoFileName + " has "
+            + docInfoSize + " non-empty lines, but there are only " + numSentences
+            + " sentences.");
+      }
+
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Copies the bytes of origFileName to newFileName, overwriting newFileName if it exists.
+   *
+   * @return true on success; false if an I/O error occurred (the error is logged)
+   */
+  private boolean copyFile(String origFileName, String newFileName) {
+    // try-with-resources closes both streams even if the copy fails midway
+    try (InputStream in = new FileInputStream(new File(origFileName));
+        OutputStream out = new FileOutputStream(new File(newFileName))) {
+      byte[] buffer = new byte[8192];
+      int len;
+      while ((len = in.read(buffer)) > 0) {
+        out.write(buffer, 0, len);
+      }
+      return true;
+    } catch (IOException e) {
+      LOG.error(e.getMessage(), e);
+      return false;
+    }
+  }
+
+  /**
+   * Renames origFileName to newFileName, deleting any existing newFileName first. Failures
+   * (missing source, unsuccessful rename) are printed as warnings at verbosity 1, not thrown.
+   */
+  private void renameFile(String origFileName, String newFileName) {
+    if (!fileExists(origFileName)) {
+      println("Warning: file " + origFileName + " does not exist! (in MertCore.renameFile)", 1);
+      return;
+    }
+
+    deleteFile(newFileName);
+    boolean renamed = new File(origFileName).renameTo(new File(newFileName));
+    if (!renamed) {
+      println("Warning: attempt to rename " + origFileName + " to " + newFileName
+          + " was unsuccessful!", 1);
+    }
+  }
+
+  /**
+   * Deletes fileName if it exists; an unsuccessful deletion is printed as a warning at
+   * verbosity 1 rather than thrown.
+   */
+  private void deleteFile(String fileName) {
+    if (!fileExists(fileName)) {
+      return;
+    }
+    boolean deleted = new File(fileName).delete();
+    if (!deleted) {
+      println("Warning: attempt to delete " + fileName + " was unsuccessful!", 1);
+    }
+  }
+
+  /**
+   * Writes line followed by the writer's line separator, then flushes the writer.
+   *
+   * @throws IOException if the underlying writer fails
+   */
+  private void writeLine(String line, BufferedWriter writer) throws IOException {
+    writer.write(line);
+    writer.newLine();
+    writer.flush();
+  }
+
+ public void finish() {
+ if (myDecoder != null) {
+ myDecoder.cleanUp();
+ }
+
+ // create config file with final values
+ createConfigFile(lambda, decoderConfigFileName + ".ZMERT.final", decoderConfigFileName
+ + ".ZMERT.orig");
+
+ // delete current decoder config file and decoder output
+ deleteFile(decoderConfigFileName);
+ deleteFile(decoderOutFileName);
+
+ // restore original name for config file (name was changed
+ // in initialize() so it doesn't get overwritten)
+ renameFile(decoderConfigFileName + ".ZMERT.orig", decoderConfigFileName);
+
+ if (finalLambdaFileName != null) {
+ try {
+ PrintWriter outFile_lambdas = new PrintWriter(finalLambdaFileName);
+ for (int c = 1; c <= numParams; ++c) {
+ outFile_lambdas.println(paramNames[c] + " ||| " + lambda[c]);
+ }
+ outFile_lambdas.close();
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+  /**
+   * Reads a Z-MERT config file and flattens it into a command-line style argument array.
+   *
+   * <p>Each remaining line must look like "-option value". Only -m, -docSet and -damianos may
+   * carry more than one value. Text from '#' onward is a comment. Lines that are empty after
+   * removing the comment and surrounding whitespace are ignored.
+   *
+   * @param fileName the config file (validated with checkFile first)
+   * @return the options and their values, in file order
+   * @throws RuntimeException if the file is missing or cannot be read; a malformed line is
+   *     printed and the JVM exits with status 70
+   */
+  private String[] cfgFileToArgsArray(String fileName) {
+    checkFile(fileName);
+
+    Vector<String> argsVector = new Vector<String>();
+
+    // try-with-resources so the reader is closed even when reading fails
+    try (BufferedReader inFile = new BufferedReader(new FileReader(fileName))) {
+      String origLine; // kept unmodified for error reporting purposes
+      while ((origLine = inFile.readLine()) != null) {
+        String line = origLine;
+
+        int comment_i = line.indexOf('#');
+        if (comment_i != -1) { // discard comment
+          line = line.substring(0, comment_i);
+        }
+        line = line.trim();
+
+        // skip blank, whitespace-only and (possibly indented) comment-only lines
+        if (line.isEmpty()) {
+          continue;
+        }
+
+        // now line should look like "-xxx XXX"
+        String[] paramA = line.split("\\s+");
+
+        if (paramA.length == 2 && paramA[0].charAt(0) == '-') {
+          argsVector.add(paramA[0]);
+          argsVector.add(paramA[1]);
+        } else if (paramA.length > 2
+            && (paramA[0].equals("-m") || paramA[0].equals("-docSet")
+                || paramA[0].equals("-damianos"))) {
+          // -m (metricName), -docSet, and -damianos are allowed to have extra options
+          for (String opt : paramA) {
+            argsVector.add(opt);
+          }
+        } else {
+          println("Malformed line in config file:");
+          println(origLine);
+          System.exit(70);
+        }
+      }
+    } catch (FileNotFoundException e) {
+      throw new RuntimeException("Z-MERT configuration file " + fileName + " was not found!", e);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    return argsVector.toArray(new String[argsVector.size()]);
+  }
+
+ private void processArgsArray(String[] args) {
+ processArgsArray(args, true);
+ }
+
+ private void processArgsArray(String[] args, boolean firstTime) {
+ /* set default values */
+ // Relevant files
+ dirPrefix = null;
+ sourceFileName = null;
+ refFileName = "reference.txt";
+ refsPerSen = 1;
+ textNormMethod = 1;
+ paramsFileName = "params.txt";
+ docInfoFileName = null;
+ finalLambdaFileName = null;
+ // MERT specs
+ metricName = "BLEU";
+ metricName_display = metricName;
+ metricOptions = new String[2];
+ metricOptions[0] = "4";
+ metricOptions[1] = "closest";
+ docSubsetInfo = new int[7];
+ docSubsetInfo[0] = 0;
+ maxMERTIterations = 20;
+ prevMERTIterations = 20;
+ minMERTIterations = 5;
+ stopMinIts = 3;
+ stopSigValue = -1;
+ //
+ // /* possibly other early stopping criteria here */
+ //
+ numOptThreads = 1;
+ saveInterFiles = 3;
+ compressFiles = 0;
+ initsPerIt = 20;
+ oneModificationPerIteration = false;
+ randInit = false;
+ seed = System.currentTimeMillis();
+ // useDisk = 2;
+ // Decoder specs
+ decoderCommandFileName = null;
+ passIterationToDecoder = false;
+ decoderOutFileName = "output.nbest";
+ validDecoderExitValue = 0;
+ decoderConfigFileName = "dec_cfg.txt";
+ sizeOfNBest = 100;
+ fakeFileNameTemplate = null;
+ fakeFileNamePrefix = null;
+ fakeFileNameSuffix = null;
+ // Output specs
+ verbosity = 1;
+ decVerbosity = 1;
+
+ damianos_method = 0;
+ damianos_param = 0.0;
+ damianos_mult = 0.0;
+
+ int i = 0;
+
+ while (i < args.length) {
+ String option = args[i];
+ // Relevant files
+ if (option.equals("-dir")) {
+ dirPrefix = args[i + 1];
+ } else if (option.equals("-s")) {
+ sourceFileName = args[i + 1];
+ } else if (option.equals("-r")) {
+ refFileName = args[i + 1];
+ } else if (option.equals("-rps")) {
+ refsPerSen = Integer.parseInt(args[i + 1]);
+ if (refsPerSen < 1) {
+ throw new RuntimeException("refsPerSen must be positive.");
+ }
+ } else if (option.equals("-txtNrm")) {
+ textNormMethod = Integer.parseInt(args[i + 1]);
+ if (textNormMethod < 0 || textNormMethod > 4) {
+ throw new RuntimeException("textNormMethod should be between 0 and 4");
+ }
+ } else if (option.equals("-p")) {
+ paramsFileName = args[i + 1];
+ } else if (option.equals("-docInfo")) {
+ docInfoFileName = args[i + 1];
+ } else if (option.equals("-fin")) {
+ finalLambdaFileName = args[i + 1];
+ // MERT specs
+ } else if (option.equals("-m")) {
+ metricName = args[i + 1];
+ metricName_display = metricName;
+ if (EvaluationMetric.knownMetricName(metricName)) {
+ int optionCount = EvaluationMetric.metricOptionCount(metricName);
+ metricOptions = new String[optionCount];
+ for (int opt = 0; opt < optionCount; ++opt) {
+ metricOptions[opt] = args[i + opt + 2];
+ }
+ i += optionCount;
+ } else {
+ throw new RuntimeException("Unknown metric name " + metricName + ".");
+ }
+ } else if (option.equals("-docSet")) {
+ String method = args[i + 1];
+
+ if (method.equals("all")) {
+ docSubsetInfo[0] = 0;
+ i += 0;
+ } else if (method.equals("bottom")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 1;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 2;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("top")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 3;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 4;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("window")) {
+ String a1 = args[i + 2];
+ a1 = a1.substring(0, a1.indexOf("d")); // size of window
+ String a2 = args[i + 4];
+ if (a2.indexOf("p") > 0) {
+ docSubsetInfo[0] = 5;
+ a2 = a2.substring(0, a2.indexOf("p"));
+ } else {
+ docSubsetInfo[0] = 6;
+ a2 = a2.substring(0, a2.indexOf("r"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a1);
+ docSubsetInfo[6] = Integer.parseInt(a2);
+ i += 3;
+ } else {
+ throw new RuntimeException("Unknown docSet method " + method + ".");
+ }
+ } else if (option.equals("-maxIt")) {
+ maxMERTIterations = Integer.parseInt(args[i + 1]);
+ if (maxMERTIterations < 1) {
+ throw new RuntimeException("maxMERTIts must be positive.");
+ }
+ } else if (option.equals("-minIt")) {
+ minMERTIterations = Integer.parseInt(args[i + 1]);
+ if (minMERTIterations < 1) {
+ throw new RuntimeException("minMERTIts must be positive.");
+ }
+ } else if (option.equals("-prevIt")) {
+ prevMERTIterations = Integer.parseInt(args[i + 1]);
+ if (prevMERTIterations < 0) {
+ throw new RuntimeException("prevMERTIts must be non-negative.");
+ }
+ } else if (option.equals("-stopIt")) {
+ stopMinIts = Integer.parseInt(args[i + 1]);
+ if (stopMinIts < 1) {
+ throw new RuntimeException("stopMinIts must be positive.");
+ }
+ } else if (option.equals("-stopSig")) {
+ stopSigValue = Double.parseDouble(args[i + 1]);
+ }
+ //
+ // /* possibly other early stopping criteria here */
+ //
+ else if (option.equals("-thrCnt")) {
+ numOptThreads = Integer.parseInt(args[i + 1]);
+ if (numOptThreads < 1) {
+ throw new RuntimeException("threadCount must be positive.");
+ }
+ } else if (option.equals("-save")) {
+ saveInterFiles = Integer.parseInt(args[i + 1]);
+ if (saveInterFiles < 0 || saveInterFiles > 3) {
+ throw new RuntimeException("save should be between 0 and 3");
+ }
+ } else if (option.equals("-compress")) {
+ compressFiles = Integer.parseInt(args[i + 1]);
+ if (compressFiles < 0 || compressFiles > 1) {
+ throw new RuntimeException("compressFiles should be either 0 or 1");
+ }
+ } else if (option.equals("-ipi")) {
+ initsPerIt = Integer.parseInt(args[i + 1]);
+ if (initsPerIt < 1) {
+ throw new RuntimeException("initsPerIt must be positive.");
+ }
+ } else if (option.equals("-opi")) {
+ int opi = Integer.parseInt(args[i + 1]);
+ if (opi == 1) {
+ oneModificationPer
<TRUNCATED>
[07/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
index 0aa41af,0000000..5226b0a
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/encoding/FeatureTypeAnalyzer.java
@@@ -1,256 -1,0 +1,264 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.encoding;
+
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.ff.FeatureMap;
+import org.apache.joshua.util.io.LineReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class FeatureTypeAnalyzer {
+
+ private static final Logger LOG = LoggerFactory.getLogger(FeatureTypeAnalyzer.class);
+
+ private ArrayList<FeatureType> types;
+
+ private Map<Integer, Integer> featureToType;
+
+ private Map<Integer, Integer> featureIdMap;
+
+ // Is the feature setup labeled.
+ private boolean labeled;
+
+ // Is the encoder configuration open for new features (that are not assumed boolean)?
+ private boolean open;
+
+ public FeatureTypeAnalyzer() {
+ this(false);
+ }
+
+ public FeatureTypeAnalyzer(boolean open) {
+ this.open = open;
+ this.types = new ArrayList<FeatureType>();
+ this.featureToType = new HashMap<Integer, Integer>();
+ this.featureIdMap = new HashMap<Integer, Integer>();
+ }
+
+ public void readConfig(String config_filename) throws IOException {
- LineReader reader = new LineReader(config_filename);
- while (reader.hasNext()) {
- // Clean up line, chop comments off and skip if the result is empty.
- String line = reader.next().trim();
- if (line.indexOf('#') != -1)
- line = line.substring(0, line.indexOf('#'));
- if (line.isEmpty())
- continue;
- String[] fields = line.split("[\\s]+");
-
- if ("encoder".equals(fields[0])) {
- // Adding an encoder to the mix.
- if (fields.length < 3) {
- throw new RuntimeException("Incomplete encoder line in config.");
++ try(LineReader reader = new LineReader(config_filename);) {
++ while (reader.hasNext()) {
++ // Clean up line, chop comments off and skip if the result is empty.
++ String line = reader.next().trim();
++ if (line.indexOf('#') != -1)
++ line = line.substring(0, line.indexOf('#'));
++ if (line.isEmpty())
++ continue;
++ String[] fields = line.split("[\\s]+");
++
++ if ("encoder".equals(fields[0])) {
++ // Adding an encoder to the mix.
++ if (fields.length < 3) {
++ throw new RuntimeException("Incomplete encoder line in config.");
++ }
++ String encoder_key = fields[1];
++ List<Integer> feature_ids = new ArrayList<Integer>();
++ for (int i = 2; i < fields.length; i++)
++ feature_ids.add(Vocabulary.id(fields[i]));
++ addFeatures(encoder_key, feature_ids);
+ }
+ String encoder_key = fields[1];
- ArrayList<Integer> feature_ids = new ArrayList<Integer>();
++ List<Integer> feature_ids = new ArrayList<Integer>();
+ for (int i = 2; i < fields.length; i++)
+ feature_ids.add(FeatureMap.hashFeature(fields[i]));
+ addFeatures(encoder_key, feature_ids);
+ }
+ }
+ }
+
+ public void addFeatures(String encoder_key, List<Integer> feature_ids) {
+ int index = addType(encoder_key);
+ for (int feature_id : feature_ids)
+ featureToType.put(feature_id, index);
+ }
+
+ private int addType(String encoder_key) {
+ FeatureType ft = new FeatureType(encoder_key);
+ int index = types.indexOf(ft);
+ if (index < 0) {
+ types.add(ft);
+ return types.size() - 1;
+ }
+ return index;
+ }
+
+ private int addType() {
+ types.add(new FeatureType());
+ return types.size() - 1;
+ }
+
+ public void observe(int feature_id, float value) {
+ Integer type_id = featureToType.get(feature_id);
+ if (type_id == null && open) {
+ type_id = addType();
+ featureToType.put(feature_id, type_id);
+ }
+ if (type_id != null)
+ types.get(type_id).observe(value);
+ }
+
+ // Inspects the collected histograms, inferring actual type of feature. Then replaces the
+ // analyzer, if present, with the most compact applicable type.
+ public void inferTypes(boolean labeled) {
+ for (FeatureType ft : types) {
+ ft.inferUncompressedType();
+ }
+ if (LOG.isInfoEnabled()) {
+ for (int id : featureToType.keySet()) {
+ LOG.info("Type inferred: {} is {}", (labeled ? FeatureMap.getFeature(id) : "Feature " + id),
+ types.get(featureToType.get(id)).encoder.getKey());
+ }
+ }
+ }
+
+ public void buildFeatureMap() {
+ int[] known_features = new int[featureToType.keySet().size()];
+ int i = 0;
+ for (int f : featureToType.keySet())
+ known_features[i++] = f;
+ Arrays.sort(known_features);
+
+ featureIdMap.clear();
+ for (i = 0; i < known_features.length; ++i)
+ featureIdMap.put(known_features[i], i);
+ }
+
+ public int getRank(int feature_id) {
+ return featureIdMap.get(feature_id);
+ }
+
+ public IntEncoder getIdEncoder() {
+ int num_features = featureIdMap.size();
+ if (num_features <= Byte.MAX_VALUE)
+ return PrimitiveIntEncoder.BYTE;
+ else if (num_features <= Character.MAX_VALUE)
+ return PrimitiveIntEncoder.CHAR;
+ else
+ return PrimitiveIntEncoder.INT;
+ }
+
+ public void write(String file_name) throws IOException {
+ File out_file = new File(file_name);
+ BufferedOutputStream buf_stream = new BufferedOutputStream(new FileOutputStream(out_file));
+ DataOutputStream out_stream = new DataOutputStream(buf_stream);
+
+ buildFeatureMap();
+
+ getIdEncoder().writeState(out_stream);
+ out_stream.writeBoolean(labeled);
+ out_stream.writeInt(types.size());
+ for (int index = 0; index < types.size(); index++)
+ types.get(index).encoder.writeState(out_stream);
+
+ out_stream.writeInt(featureToType.size());
+ for (int feature_id : featureToType.keySet()) {
+ if (labeled)
+ out_stream.writeUTF(FeatureMap.getFeature(feature_id));
+ else
+ out_stream.writeInt(feature_id);
+ out_stream.writeInt(featureIdMap.get(feature_id));
+ out_stream.writeInt(featureToType.get(feature_id));
+ }
+ out_stream.close();
+ }
+
++ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ for (int feature_id : featureToType.keySet()) {
+ sb.append(types.get(featureToType.get(feature_id)).analyzer.toString(FeatureMap.getFeature(feature_id)));
+ }
+ System.out.println(sb.toString());
+ return sb.toString();
+ }
+
+ public boolean isLabeled() {
+ return labeled;
+ }
+
+ public void setLabeled(boolean labeled) {
+ this.labeled = labeled;
+ }
+
- class FeatureType {
++ static class FeatureType {
+ FloatEncoder encoder;
+ Analyzer analyzer;
+ int bits;
+
+ FeatureType() {
+ encoder = null;
+ analyzer = new Analyzer();
+ bits = -1;
+ }
+
+ FeatureType(String key) {
+ // either throws or returns non-null
+ FloatEncoder e = EncoderFactory.getFloatEncoder(key);
+ encoder = e;
+ analyzer = null;
+ bits = -1;
+ }
+
+ void inferUncompressedType() {
+ if (encoder != null)
+ return;
+ encoder = analyzer.inferUncompressedType();
+ analyzer = null;
+ }
+
+ void inferType() {
+ if (encoder != null)
+ return;
+ encoder = analyzer.inferType(bits);
+ analyzer = null;
+ }
+
+ void observe(float value) {
+ if (analyzer != null)
+ analyzer.add(value);
+ }
+
++ @Override
+ public boolean equals(Object t) {
+ if (t != null && t instanceof FeatureType) {
+ FeatureType that = (FeatureType) t;
+ if (this.encoder != null) {
+ return this.encoder.equals(that.encoder);
+ } else {
+ if (that.encoder != null)
+ return false;
+ if (this.analyzer != null)
+ return this.analyzer.equals(that.analyzer);
+ }
+ }
+ return false;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java
index 0000000,0000000..42dd236
new file mode 100644
--- /dev/null
+++ b/joshua-core/src/main/java/org/apache/joshua/util/io/ExistingUTF8EncodedTextFile.java
@@@ -1,0 -1,0 +1,77 @@@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one
++ * or more contributor license agreements. See the NOTICE file
++ * distributed with this work for additional information
++ * regarding copyright ownership. The ASF licenses this file
++ * to you under the Apache License, Version 2.0 (the
++ * "License"); you may not use this file except in compliance
++ * with the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing,
++ * software distributed under the License is distributed on an
++ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++ * KIND, either express or implied. See the License for the
++ * specific language governing permissions and limitations
++ * under the License.
++ */
++package org.apache.joshua.util.io;
++
++import java.io.FileNotFoundException;
++import java.io.IOException;
++import java.nio.charset.StandardCharsets;
++import java.nio.file.Files;
++import java.nio.file.Path;
++import java.nio.file.Paths;
++import java.util.function.Predicate;
++import java.util.stream.Stream;
++
/**
 * A {@link StandardCharsets#UTF_8} text file that is known to exist. Construction throws a
 * {@link FileNotFoundException} if nothing exists at the given {@link Path} (or the
 * {@link String} form of a path).
 */
public class ExistingUTF8EncodedTextFile {
  private static final Predicate<String> IS_EMPTY = String::isEmpty;

  private final Path p;

  /**
   * @param pathStr string form of the path of the file
   * @throws FileNotFoundException if nothing exists at {@code pathStr}
   */
  public ExistingUTF8EncodedTextFile(String pathStr) throws FileNotFoundException {
    this(Paths.get(pathStr));
  }

  /**
   * @param p path of the file
   * @throws FileNotFoundException if nothing exists at {@code p}
   */
  public ExistingUTF8EncodedTextFile(Path p) throws FileNotFoundException {
    this.p = p;
    if (!Files.exists(p))
      throw new FileNotFoundException("Did not find the file at path: " + p.toString());
  }

  /**
   * @return the {@link Path} representing this object
   */
  public Path getPath() {
    return this.p;
  }

  /**
   * @return the number of lines in the file represented by this object
   * @throws IOException on inability to read file (maybe it's not a text file)
   * @throws ArithmeticException if the file has more than {@link Integer#MAX_VALUE} lines
   */
  public int getNumberOfLines() throws IOException {
    try (Stream<String> ls = Files.lines(this.p, StandardCharsets.UTF_8)) {
      // toIntExact fails loudly instead of silently wrapping to a wrong (possibly negative) count.
      return Math.toIntExact(ls.count());
    }
  }

  /**
   * @return the number of non-empty lines (lines with at least one character, whitespace
   *         included) in the file represented by this object
   * @throws IOException on inability to read file (maybe it's not a text file)
   * @throws ArithmeticException if there are more than {@link Integer#MAX_VALUE} such lines
   */
  public int getNumberOfNonEmptyLines() throws IOException {
    try (Stream<String> ls = Files.lines(this.p, StandardCharsets.UTF_8)) {
      return Math.toIntExact(ls.filter(IS_EMPTY.negate()).count());
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/io/IndexedReader.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/io/IndexedReader.java
index f357e55,0000000..d206544
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/io/IndexedReader.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/io/IndexedReader.java
@@@ -1,155 -1,0 +1,160 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.io;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+
+/**
+ * Wraps a reader with "line" index information.
- *
++ *
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @version $LastChangedDate: 2009-03-26 15:06:57 -0400 (Thu, 26 Mar 2009) $
+ */
+public class IndexedReader<E> implements Reader<E> {
-
+ /** A name for the type of elements the reader produces. */
+ private final String elementName;
+
+ /** The number of elements the reader has delivered so far. */
+ private int lineNumber;
+
+ /** The underlying reader. */
+ private final Reader<E> reader;
+
+ public IndexedReader(String elementName, Reader<E> reader) {
+ this.elementName = elementName;
+ this.lineNumber = 0;
+ this.reader = reader;
+ }
+
- /**
++ /**
+ * Return the number of elements delivered so far.
+ * @return integer representing the number of elements delivered so far
+ */
+ public int index() {
+ return this.lineNumber;
+ }
+
+
+ /**
+ * Wrap an IOException's message with the index when it occured.
+ * @param oldError the old {@link java.io.IOException} we wish to wrap
+ * @return the new wrapped {@link java.io.IOException}
+ */
+ public IOException wrapIOException(IOException oldError) {
+ IOException newError =
+ new IOException("At " + this.elementName + " " + this.lineNumber + ": "
+ + oldError.getMessage());
+ newError.initCause(oldError);
+ return newError;
+ }
+
+ // ===============================================================
+ // Reader
+ // ===============================================================
+
- /**
++ /**
+ * Delegated to the underlying reader.
+ * @return true if the reader is ready
+ * @throws IOException if there is an error determining readiness
+ */
+ @Override
+ public boolean ready() throws IOException {
+ try {
+ return this.reader.ready();
+ } catch (IOException oldError) {
+ throw wrapIOException(oldError);
+ }
+ }
+
+
+ /**
+ * Delegated to the underlying reader. Note that we do not have a <code>finalize()</code> method;
+ * however, when we fall out of scope, the underlying reader will too, so its finalizer may be
+ * called. For correctness, be sure to manually close all readers.
+ */
++ @Override
+ public void close() throws IOException {
+ try {
+ this.reader.close();
+ } catch (IOException oldError) {
+ throw wrapIOException(oldError);
+ }
+ }
+
+
+ /** Delegated to the underlying reader. */
++ @Override
+ public E readLine() throws IOException {
+ E line;
+ try {
+ line = this.reader.readLine();
+ } catch (IOException oldError) {
+ throw wrapIOException(oldError);
+ }
+ ++this.lineNumber;
+ return line;
+ }
+
+
+ // ===============================================================
+ // Iterable -- because sometimes Java can be very stupid
+ // ===============================================================
+
+ /** Return self as an iterator. */
++ @Override
+ public Iterator<E> iterator() {
+ return this;
+ }
+
+
+ // ===============================================================
+ // Iterator
+ // ===============================================================
+
+ /** Delegated to the underlying reader. */
++ @Override
+ public boolean hasNext() {
+ return this.reader.hasNext();
+ }
+
+
+ /** Delegated to the underlying reader. */
++ @Override
+ public E next() throws NoSuchElementException {
+ E line = this.reader.next();
+ // Let exceptions out, we'll wrap any errors a closing time.
+
+ ++this.lineNumber;
+ return line;
+ }
+
+
+ /**
+ * If the underlying reader supports removal, then so do we. Note that the {@link #index()} method
+ * returns the number of elements delivered to the client, so removing an element from the
+ * underlying collection does not affect that number.
+ */
++ @Override
+ public void remove() throws UnsupportedOperationException {
+ this.reader.remove();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/io/LineReader.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/io/LineReader.java
index d63763d,0000000..ea5d8f1
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/io/LineReader.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/io/LineReader.java
@@@ -1,368 -1,0 +1,309 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.io;
+
+import java.io.BufferedReader;
++import java.io.File;
+import java.io.FileDescriptor;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
- import java.io.File;
- import java.nio.charset.Charset;
++import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import java.util.zip.GZIPInputStream;
+
+import org.apache.joshua.decoder.Decoder;
+
+/**
+ * This class provides an Iterator interface to a BufferedReader. This covers the most common
+ * use-cases for reading from files without ugly code to check whether we got a line or not.
- *
++ *
+ * @author wren ng thornton wren@users.sourceforge.net
+ * @author Matt Post post@cs.jhu.edu
+ */
+public class LineReader implements Reader<String>, AutoCloseable {
+
+ /*
- * Note: charset name is case-agnostic "UTF-8" is the canonical name "UTF8", "unicode-1-1-utf-8"
- * are aliases Java doesn't distinguish utf8 vs UTF-8 like Perl does
- */
- private static final Charset FILE_ENCODING = Charset.forName("UTF-8");
-
- /*
+ * The reader and its underlying input stream. We need to keep a hold of the underlying
+ * input stream so that we can query how many raw bytes it's read (for a generic progress
+ * meter that works across GZIP'ed and plain text files).
+ */
+ private BufferedReader reader;
+ private ProgressInputStream rawStream;
+
+ private String buffer;
+ private IOException error;
+
+ private int lineno = 0;
-
++
+ private boolean display_progress = false;
-
++
+ private int progress = 0;
+
+ // ===============================================================
+ // Constructors and destructors
+ // ===============================================================
+
+ /**
+ * Opens a file for iterating line by line. The special "-" filename can be used to specify
+ * STDIN. GZIP'd files are tested for automatically.
- *
++ *
+ * @param filename the file to be opened ("-" for STDIN)
+ * @throws IOException if there is an error reading the input file
+ */
+ public LineReader(String filename) throws IOException {
-
++
+ display_progress = (Decoder.VERBOSE >= 1);
-
++
+ progress = 0;
-
- InputStream stream = null;
++
++ InputStream stream = null;
+ long totalBytes = -1;
+ if (filename.equals("-")) {
+ rawStream = null;
+ stream = new FileInputStream(FileDescriptor.in);
+ } else {
+ totalBytes = new File(filename).length();
+ rawStream = new ProgressInputStream(new FileInputStream(filename), totalBytes);
+
+ try {
+ stream = new GZIPInputStream(rawStream);
+ } catch (Exception e) {
+ // GZIP ate a byte, so reset
+ rawStream.close();
+ stream = rawStream = new ProgressInputStream(new FileInputStream(filename), totalBytes);
+ }
- }
-
- this.reader = new BufferedReader(new InputStreamReader(stream, FILE_ENCODING));
++ }
++
++ this.reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8));
+ }
-
++
+ public LineReader(String filename, boolean show_progress) throws IOException {
+ this(filename);
+ display_progress = (Decoder.VERBOSE >= 1 && show_progress);
+ }
+
+
+ /**
+ * Wraps an InputStream for iterating line by line. Stream encoding is assumed to be UTF-8.
+ * @param in an {@link java.io.InputStream} to wrap and iterate over line by line
+ */
+ public LineReader(InputStream in) {
- this.reader = new BufferedReader(new InputStreamReader(in, FILE_ENCODING));
++ this.reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
+ display_progress = false;
+ }
-
++
+ /**
- * Chain to the underlying {@link ProgressInputStream}.
- *
++ * Chain to the underlying {@link ProgressInputStream}.
++ *
+ * @return an integer from 0..100, indicating how much of the file has been read.
+ */
+ public int progress() {
+ return rawStream == null ? 0 : rawStream.progress();
+ }
-
++
+ /**
+ * This method will close the file handle, and will raise any exceptions that occured during
+ * iteration. The method is idempotent, and all calls after the first are no-ops (unless the
+ * thread was interrupted or killed). For correctness, you <b>must</b> call this method before the
+ * object falls out of scope.
+ * @throws IOException if there is an error closing the file handler
+ */
++ @Override
+ public void close() throws IOException {
+
+ this.buffer = null; // Just in case it's a large string
+
+ if (null != this.reader) {
+ try {
+ // We assume the wrappers will percolate this down.
+ this.reader.close();
+
+ } catch (IOException e) {
+ // We need to trash our cached error for idempotence.
+ // Presumably the closing error is the more important
+ // one to throw.
+ this.error = null;
+ throw e;
+
+ } finally {
+ this.reader = null;
+ }
+ }
+
+ if (null != this.error) {
+ IOException e = this.error;
+ this.error = null;
+ throw e;
+ }
+ }
+
-
- /**
- * We attempt to avoid leaking file descriptors if you fail to call close before the object falls
- * out of scope. However, the language spec makes <b>no guarantees</b> about timeliness of garbage
- * collection. It is a bug to rely on this method to release the resources. Also, the garbage
- * collector will discard any exceptions that have queued up, without notifying the application in
- * any way.
- *
- * Having a finalizer means the JVM can't do "fast allocation" of LineReader objects (or
- * subclasses). This isn't too important due to disk latency, but may be worth noting.
- *
- * @see <a
- * href="http://java2go.blogspot.com/2007/09/javaone-2007-performance-tips-2-finish.html">Performance
- * Tips</a>
- * @see <a
- * href="http://www.javaworld.com/javaworld/jw-06-1998/jw-06-techniques.html?page=1">Techniques</a>
- */
- protected void finalize() throws Throwable {
- try {
- this.close();
- } catch (IOException e) {
- // Do nothing. The GC will discard the exception
- // anyways, but it may cause us to linger on the heap.
- } finally {
- super.finalize();
- }
- }
-
-
-
+ // ===============================================================
+ // Reader
+ // ===============================================================
+
+ // Copied from interface documentation.
+ /** Determine if the reader is ready to read a line. */
++ @Override
+ public boolean ready() throws IOException {
+ return this.reader.ready();
+ }
+
+
+ /**
+ * This method is like next() except that it throws the IOException directly. If there are no
+ * lines to be read then null is returned.
+ */
++ @Override
+ public String readLine() throws IOException {
+ if (this.hasNext()) {
+ String line = this.buffer;
+ this.buffer = null;
+ return line;
+
+ } else {
+ if (null != this.error) {
+ IOException e = this.error;
+ this.error = null;
+ throw e;
+ }
+ return null;
+ }
+ }
+
+
+ // ===============================================================
+ // Iterable -- because sometimes Java can be very stupid
+ // ===============================================================
+
+ /** Return self as an iterator. */
++ @Override
+ public Iterator<String> iterator() {
+ return this;
+ }
+
+
+ // ===============================================================
+ // Iterator
+ // ===============================================================
+
+ // Copied from interface documentation.
+ /**
+ * Returns <code>true</code> if the iteration has more elements. (In other words, returns
+ * <code>true</code> if <code>next</code> would return an element rather than throwing an
+ * exception.)
+ */
++ @Override
+ public boolean hasNext() {
+ if (null != this.buffer) {
+ return true;
+
+ } else if (null != this.error) {
+ return false;
+
+ } else {
+ // We're not allowed to throw IOException from within Iterator
+ try {
+ this.buffer = this.reader.readLine();
+ } catch (IOException e) {
+ this.buffer = null;
+ this.error = e;
+ return false;
+ }
+ return (null != this.buffer);
+ }
+ }
+
+
+ /**
+ * Return the next line of the file. If an error is encountered, NoSuchElementException is thrown.
+ * The actual IOException encountered will be thrown later, when the LineReader is closed. Also if
+ * there is no line to be read then NoSuchElementException is thrown.
+ */
++ @Override
+ public String next() throws NoSuchElementException {
+ if (this.hasNext()) {
+ if (display_progress) {
+ int newProgress = (reader != null) ? progress() : 100;
+// System.err.println(String.format("OLD %d NEW %d", progress, newProgress));
-
++
+ if (newProgress > progress) {
+ for (int i = progress + 1; i <= newProgress; i++)
+ if (i == 97) {
+ System.err.print("1");
+ } else if (i == 98) {
+ System.err.print("0");
+ } else if (i == 99) {
+ System.err.print("0");
+ } else if (i == 100) {
+ System.err.println("%");
+ } else if (i % 10 == 0) {
+ System.err.print(String.format("%d", i));
+ System.err.flush();
+ } else if ((i - 1) % 10 == 0)
+ ; // skip at 11 since 10, 20, etc take two digits
+ else {
+ System.err.print(".");
+ System.err.flush();
+ }
+ progress = newProgress;
+ }
+ }
-
++
+ String line = this.buffer;
+ this.lineno++;
+ this.buffer = null;
+ return line;
+ } else {
+ throw new NoSuchElementException();
+ }
+ }
-
++
+ /* Get the line number of the last line that was returned */
+ public int lineno() {
+ return this.lineno;
+ }
+
+ /** Unsupported. */
++ @Override
+ public void remove() throws UnsupportedOperationException {
+ throw new UnsupportedOperationException();
+ }
+
-
+ /**
- * Iterates over all lines, ignoring their contents, and returns the count of lines. If some lines
- * have already been read, this will return the count of remaining lines. Because no lines will
- * remain after calling this method, we implicitly call close.
- *
- * @return the number of lines read
- * @throws IOException if there is an error reading lines
- */
- public int countLines() throws IOException {
- int lines = 0;
-
- while (this.hasNext()) {
- this.next();
- lines++;
- }
- this.close();
-
- return lines;
- }
-
- /**
+ * Example usage code.
+ * @param args an input file
+ */
+ public static void main(String[] args) {
+ if (1 != args.length) {
+ System.out.println("Usage: java LineReader filename");
+ System.exit(1);
+ }
+
- try {
-
- LineReader in = new LineReader(args[0]);
- try {
- for (String line : in) {
-
- System.out.println(line);
-
- }
- } finally {
- in.close();
++ try (LineReader in = new LineReader(args[0]);) {
++ for (String line : in) {
++ System.out.println(line);
+ }
-
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/io/Reader.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/io/Reader.java
index cab6d74,0000000..e3a150e
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/io/Reader.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/io/Reader.java
@@@ -1,51 -1,0 +1,52 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.io;
+
+import java.io.IOException;
+import java.util.Iterator;
+
/**
 * Common interface for Reader type objects: an {@link Iterator} over "lines" that is also
 * {@link Iterable} (usable in for-each loops) and {@link AutoCloseable} (usable in
 * try-with-resources statements).
 *
 * @author wren ng thornton wren@users.sourceforge.net
 * @version $LastChangedDate: 2009-03-26 15:06:57 -0400 (Thu, 26 Mar 2009) $
 */
public interface Reader<E> extends Iterable<E>, Iterator<E>, AutoCloseable {

  /**
   * Close the reader, freeing all resources. Narrows {@link AutoCloseable#close()} so that only
   * {@link IOException} can be thrown.
   * @throws IOException if there is an error closing the reader instance
   */
  @Override
  void close() throws IOException;

  /**
   * Determine if the reader is ready to read a line.
   * @return true if it is ready
   * @throws IOException if there is an error whilst determining whether the reader is ready
   */
  boolean ready() throws IOException;

  /**
   * Read a "line" and return an object representing it.
   * @return an object representing a single line
   * @throws IOException if there is an error reading lines
   */
  E readLine() throws IOException;
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
index 33a4e9a,0000000..ab291be
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/quantization/Quantizer.java
@@@ -1,45 -1,0 +1,43 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.quantization;
-
- import java.io.DataInputStream;
- import java.io.DataOutputStream;
- import java.io.IOException;
- import java.nio.ByteBuffer;
-
- public interface Quantizer {
-
- public float read(ByteBuffer stream, int position);
-
- public void write(ByteBuffer stream, float value);
-
- public void initialize();
-
- public void add(float key);
-
- public void finalize();
-
- public String getKey();
-
- public void writeState(DataOutputStream out) throws IOException;
-
- public void readState(DataInputStream in) throws IOException;
-
- public int size();
++
++import java.io.DataInputStream;
++import java.io.DataOutputStream;
++import java.io.IOException;
++import java.nio.ByteBuffer;
++
++public interface Quantizer {
++
++ public float read(ByteBuffer stream, int position);
++
++ public void write(ByteBuffer stream, float value);
++
++ public void initialize();
++
++ public void add(float key);
++
++ public String getKey();
++
++ public void writeState(DataOutputStream out) throws IOException;
++
++ public void readState(DataInputStream in) throws IOException;
++
++ public int size();
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
index f4765f9,0000000..39aef36
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/quantization/QuantizerConfiguration.java
@@@ -1,119 -1,0 +1,114 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.quantization;
+
- import java.io.BufferedInputStream;
- import java.io.BufferedOutputStream;
- import java.io.DataInputStream;
- import java.io.DataOutputStream;
- import java.io.File;
- import java.io.FileInputStream;
- import java.io.FileOutputStream;
- import java.io.IOException;
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
++import java.io.BufferedInputStream;
++import java.io.BufferedOutputStream;
++import java.io.DataInputStream;
++import java.io.DataOutputStream;
++import java.io.File;
++import java.io.FileInputStream;
++import java.io.FileOutputStream;
++import java.io.IOException;
++import java.util.ArrayList;
++import java.util.HashMap;
++import java.util.List;
++import java.util.Map;
+
- import org.apache.joshua.corpus.Vocabulary;
++import org.apache.joshua.corpus.Vocabulary;
+
- public class QuantizerConfiguration {
++public class QuantizerConfiguration {
+
- private static final Quantizer DEFAULT;
++ private static final Quantizer DEFAULT;
+
- private ArrayList<Quantizer> quantizers;
- private Map<Integer, Integer> quantizerByFeatureId;
++ private ArrayList<Quantizer> quantizers;
++ private Map<Integer, Integer> quantizerByFeatureId;
+
- static {
- DEFAULT = new BooleanQuantizer();
- }
++ static {
++ DEFAULT = new BooleanQuantizer();
++ }
+
- public QuantizerConfiguration() {
- quantizers = new ArrayList<Quantizer>();
- quantizerByFeatureId = new HashMap<Integer, Integer>();
- }
++ public QuantizerConfiguration() {
++ quantizers = new ArrayList<Quantizer>();
++ quantizerByFeatureId = new HashMap<Integer, Integer>();
++ }
+
- public void add(String quantizer_key, List<Integer> feature_ids) {
- Quantizer q = QuantizerFactory.get(quantizer_key);
- quantizers.add(q);
- int index = quantizers.size() - 1;
- for (int feature_id : feature_ids)
- quantizerByFeatureId.put(feature_id, index);
- }
++ public void add(String quantizer_key, List<Integer> feature_ids) {
++ Quantizer q = QuantizerFactory.get(quantizer_key);
++ quantizers.add(q);
++ int index = quantizers.size() - 1;
++ for (int feature_id : feature_ids)
++ quantizerByFeatureId.put(feature_id, index);
++ }
+
- public void initialize() {
- for (Quantizer q : quantizers)
- q.initialize();
- }
++ public void initialize() {
++ for (Quantizer q : quantizers)
++ q.initialize();
++ }
+
- public void finalize() {
- for (Quantizer q : quantizers)
- q.finalize();
- }
++ public final Quantizer get(int feature_id) {
++ Integer index = quantizerByFeatureId.get(feature_id);
++ return (index != null ? quantizers.get(index) : DEFAULT);
++ }
+
- public final Quantizer get(int feature_id) {
- Integer index = quantizerByFeatureId.get(feature_id);
- return (index != null ? quantizers.get(index) : DEFAULT);
- }
++ public void read(String file_name) throws IOException {
++ quantizers.clear();
++ quantizerByFeatureId.clear();
+
- public void read(String file_name) throws IOException {
- quantizers.clear();
- quantizerByFeatureId.clear();
++ File quantizer_file = new File(file_name);
++ DataInputStream in_stream =
++ new DataInputStream(new BufferedInputStream(new FileInputStream(quantizer_file)));
++ int num_quantizers = in_stream.readInt();
++ quantizers.ensureCapacity(num_quantizers);
++ for (int i = 0; i < num_quantizers; i++) {
++ String key = in_stream.readUTF();
++ Quantizer q = QuantizerFactory.get(key);
++ q.readState(in_stream);
++ quantizers.add(q);
++ }
++ int num_mappings = in_stream.readInt();
++ for (int i = 0; i < num_mappings; i++) {
++ String feature_name = in_stream.readUTF();
++ int feature_id = Vocabulary.id(feature_name);
++ int quantizer_index = in_stream.readInt();
++ if (quantizer_index >= num_quantizers) {
++ throw new RuntimeException("Error deserializing QuanitzerConfig. " + "Feature "
++ + feature_name + " referring to quantizer " + quantizer_index + " when only "
++ + num_quantizers + " known.");
++ }
++ this.quantizerByFeatureId.put(feature_id, quantizer_index);
++ }
++ in_stream.close();
++ }
+
- File quantizer_file = new File(file_name);
- DataInputStream in_stream =
- new DataInputStream(new BufferedInputStream(new FileInputStream(quantizer_file)));
- int num_quantizers = in_stream.readInt();
- quantizers.ensureCapacity(num_quantizers);
- for (int i = 0; i < num_quantizers; i++) {
- String key = in_stream.readUTF();
- Quantizer q = QuantizerFactory.get(key);
- q.readState(in_stream);
- quantizers.add(q);
- }
- int num_mappings = in_stream.readInt();
- for (int i = 0; i < num_mappings; i++) {
- String feature_name = in_stream.readUTF();
- int feature_id = Vocabulary.id(feature_name);
- int quantizer_index = in_stream.readInt();
- if (quantizer_index >= num_quantizers) {
- throw new RuntimeException("Error deserializing QuanitzerConfig. " + "Feature "
- + feature_name + " referring to quantizer " + quantizer_index + " when only "
- + num_quantizers + " known.");
- }
- this.quantizerByFeatureId.put(feature_id, quantizer_index);
- }
- in_stream.close();
- }
-
- public void write(String file_name) throws IOException {
- File vocab_file = new File(file_name);
- DataOutputStream out_stream =
- new DataOutputStream(new BufferedOutputStream(new FileOutputStream(vocab_file)));
- out_stream.writeInt(quantizers.size());
- for (int index = 0; index < quantizers.size(); index++)
- quantizers.get(index).writeState(out_stream);
- out_stream.writeInt(quantizerByFeatureId.size());
- for (int feature_id : quantizerByFeatureId.keySet()) {
- out_stream.writeUTF(Vocabulary.word(feature_id));
- out_stream.writeInt(quantizerByFeatureId.get(feature_id));
- }
- out_stream.close();
- }
++ public void write(String file_name) throws IOException {
++ File vocab_file = new File(file_name);
++ DataOutputStream out_stream =
++ new DataOutputStream(new BufferedOutputStream(new FileOutputStream(vocab_file)));
++ out_stream.writeInt(quantizers.size());
++ for (int index = 0; index < quantizers.size(); index++)
++ quantizers.get(index).writeState(out_stream);
++ out_stream.writeInt(quantizerByFeatureId.size());
++ for (int feature_id : quantizerByFeatureId.keySet()) {
++ out_stream.writeUTF(Vocabulary.word(feature_id));
++ out_stream.writeInt(quantizerByFeatureId.get(feature_id));
++ }
++ out_stream.close();
++ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
index e81e945,0000000..a241cdf
mode 100644,000000..100644
--- a/joshua-core/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
+++ b/joshua-core/src/main/java/org/apache/joshua/util/quantization/StatelessQuantizer.java
@@@ -1,38 -1,0 +1,40 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.util.quantization;
+
- import java.io.DataInputStream;
- import java.io.DataOutputStream;
- import java.io.IOException;
++import java.io.DataInputStream;
++import java.io.DataOutputStream;
++import java.io.IOException;
+
- abstract class StatelessQuantizer implements Quantizer {
++abstract class StatelessQuantizer implements Quantizer {
+
- public void initialize() {}
++ @Override
++ public void initialize() {}
+
- public void add(float key) {}
++ @Override
++ public void add(float key) {}
+
- public void finalize() {}
++ @Override
++ public void writeState(DataOutputStream out) throws IOException {
++ out.writeUTF(getKey());
++ }
+
- public void writeState(DataOutputStream out) throws IOException {
- out.writeUTF(getKey());
- }
-
- public void readState(DataInputStream in) throws IOException {}
++ @Override
++ public void readState(DataInputStream in) throws IOException {}
+}
[16/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b0b70627/joshua-core/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
----------------------------------------------------------------------
diff --cc joshua-core/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
index 396c4dc,0000000..b21ab71
mode 100755,000000..100755
--- a/joshua-core/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
+++ b/joshua-core/src/main/java/org/apache/joshua/adagrad/AdaGradCore.java
@@@ -1,3126 -1,0 +1,2926 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.joshua.adagrad;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Scanner;
+import java.util.TreeSet;
+import java.util.Vector;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+import org.apache.joshua.corpus.Vocabulary;
+import org.apache.joshua.decoder.Decoder;
+import org.apache.joshua.decoder.JoshuaConfiguration;
+import org.apache.joshua.metrics.EvaluationMetric;
+import org.apache.joshua.util.StreamGobbler;
-
++import org.apache.joshua.util.io.ExistingUTF8EncodedTextFile;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This code was originally written by Yuan Cao, who copied the MERT code to produce this file.
+ */
+
+public class AdaGradCore {
+
+ private static final Logger LOG = LoggerFactory.getLogger(AdaGradCore.class);
+ private final static double NegInf = (-1.0 / 0.0);
+ private final static double PosInf = (+1.0 / 0.0);
+ private final static double epsilon = 1.0 / 1000000;
+ private final static DecimalFormat f4 = new DecimalFormat("###0.0000");
+
+ private final JoshuaConfiguration joshuaConfiguration;
- private final Runtime myRuntime = Runtime.getRuntime();
+
+ private TreeSet<Integer>[] indicesOfInterest_all;
+
- private int progress;
-
+ private int verbosity; // anything of priority <= verbosity will be printed
+ // (lower value for priority means more important)
+
+ private Random randGen;
- private int generatedRands;
+
+ private int numSentences;
+ // number of sentences in the dev set
+ // (aka the "MERT training" set)
+
+ private int numDocuments;
+ // number of documents in the dev set
+ // this should be 1, unless doing doc-level optimization
+
+ private int[] docOfSentence;
+ // docOfSentence[i] stores which document contains the i'th sentence.
+ // docOfSentence is 0-indexed, as are the documents (i.e. first doc is indexed 0)
+
+ private int[] docSubsetInfo;
+ // stores information regarding which subset of the documents are evaluated
+ // [0]: method (0-6)
+ // [1]: first (1-indexed)
+ // [2]: last (1-indexed)
+ // [3]: size
+ // [4]: center
+ // [5]: arg1
+ // [6]: arg2
+ // [1-6] are 0 for method 0, [6] is 0 for methods 1-4 as well
+ // only [1] and [2] are needed for optimization. The rest are only needed for an output message.
+
+ private int refsPerSen;
+ // number of reference translations per sentence
+
+ private int textNormMethod;
+ // 0: no normalization, 1: "NIST-style" tokenization, and also rejoin 'm, 're, *'s, 've, 'll, 'd,
+ // and n't,
+ // 2: apply 1 and also rejoin dashes between letters, 3: apply 1 and also drop non-ASCII
+ // characters
+ // 4: apply 1+2+3
+
+ private int numParams;
+ // total number of firing features
+ // this number may increase overtime as new n-best lists are decoded
+ // initially it is equal to the # of params in the parameter config file
+ private int numParamsOld;
+ // number of features before observing the new features fired in the current iteration
+
+ private double[] normalizationOptions;
+ // How should a lambda[] vector be normalized (before decoding)?
+ // nO[0] = 0: no normalization
+ // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+ // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+ // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+ // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+ /* *********************************************************** */
+ /* NOTE: indexing starts at 1 in the following few arrays: */
+ /* *********************************************************** */
+
+ // private double[] lambda;
+ private ArrayList<Double> lambda = new ArrayList<>();
+ // the current weight vector. NOTE: indexing starts at 1.
+ private final ArrayList<Double> bestLambda = new ArrayList<>();
+ // the best weight vector across all iterations
+
+ private boolean[] isOptimizable;
+ // isOptimizable[c] = true iff lambda[c] should be optimized
+
+ private double[] minRandValue;
+ private double[] maxRandValue;
+ // when choosing a random value for the lambda[c] parameter, it will be
+ // chosen from the [minRandValue[c],maxRandValue[c]] range.
+ // (*) minRandValue and maxRandValue must be real values, but not -Inf or +Inf
+
+ private double[] defaultLambda;
+ // "default" parameter values; simply the values read in the parameter file
+ // USED FOR NON-OPTIMIZABLE (FIXED) FEATURES
+
+ /* *********************************************************** */
+ /* *********************************************************** */
+
+ private Decoder myDecoder;
+ // COMMENT OUT if decoder is not Joshua
+
+ private String decoderCommand;
+ // the command that runs the decoder; read from decoderCommandFileName
+
+ private int decVerbosity;
+ // verbosity level for decoder output. If 0, decoder output is ignored.
+ // If 1, decoder output is printed.
+
+ private int validDecoderExitValue;
+ // return value from running the decoder command that indicates success
+
+ private int numOptThreads;
+ // number of threads to run things in parallel
+
+ private int saveInterFiles;
+ // 0: nothing, 1: only configs, 2: only n-bests, 3: both configs and n-bests
+
+ private int compressFiles;
+ // should AdaGrad gzip the large files? If 0, no compression takes place.
+ // If 1, compression is performed on: decoder output files, temp sents files,
+ // and temp feats files.
+
+ private int sizeOfNBest;
+ // size of N-best list generated by decoder at each iteration
+ // (aka simply N, but N is a bad variable name)
+
+ private long seed;
+ // seed used to create random number generators
+
+ private boolean randInit;
+ // if true, parameters are initialized randomly. If false, parameters
+ // are initialized using values from parameter file.
+
+ private int maxMERTIterations, minMERTIterations, prevMERTIterations;
+ // max: maximum number of MERT iterations
+ // min: minimum number of MERT iterations before an early MERT exit
+ // prev: number of previous MERT iterations from which to consider candidates (in addition to
+ // the candidates from the current iteration)
+
+ private double stopSigValue;
+ // early MERT exit if no weight changes by more than stopSigValue
+ // (but see minMERTIterations above and stopMinIts below)
+
+ private int stopMinIts;
+ // some early stopping criterion must be satisfied in stopMinIts *consecutive* iterations
+ // before an early exit (but see minMERTIterations above)
+
+ private boolean oneModificationPerIteration;
+ // if true, each MERT iteration performs at most one parameter modification.
+ // If false, a new MERT iteration starts (i.e. a new N-best list is
+ // generated) only after the previous iteration reaches a local maximum.
+
+ private String metricName;
+ // name of evaluation metric optimized by MERT
+
+ private String metricName_display;
+ // name of evaluation metric optimized by MERT, possibly with "doc-level " prefixed
+
+ private String[] metricOptions;
+ // options for the evaluation metric (e.g. for BLEU, maxGramLength and effLengthMethod)
+
+ private EvaluationMetric evalMetric;
+ // the evaluation metric used by MERT
+
+ private int suffStatsCount;
+ // number of sufficient statistics for the evaluation metric
+
+ private String tmpDirPrefix;
+ // prefix for the AdaGrad.temp.* files
+
+ private boolean passIterationToDecoder;
+ // should the iteration number be passed as an argument to decoderCommandFileName?
+
+ // used by adagrad
+ private boolean needShuffle = true; // shuffle the training sentences or not
+ private boolean needAvg = true; // average the weights or not?
+ private boolean usePseudoBleu = true; // need to use pseudo corpus to compute bleu?
+ private boolean returnBest = true; // return the best weight during tuning
+ private boolean needScale = true; // need scaling?
- private String trainingMode;
+ private int oraSelectMode = 1;
+ private int predSelectMode = 1;
+ private int adagradIter = 1;
+ private int regularization = 2;
+ private int batchSize = 1;
+ private double eta;
+ private double lam;
+ private double R = 0.99; // corpus decay when pseudo corpus is used for bleu computation
+ // private double sentForScale = 0.15; //percentage of sentences for scale factor estimation
+ private double scoreRatio = 5.0; // scale so that model_score/metric_score = scoreratio
+ private double prevMetricScore = 0; // final metric score of the previous iteration, used only
+ // when returnBest = true
+
+ private String dirPrefix; // where are all these files located?
+ private String paramsFileName, docInfoFileName, finalLambdaFileName;
+ private String sourceFileName, refFileName, decoderOutFileName;
+ private String decoderConfigFileName, decoderCommandFileName;
+ private String fakeFileNameTemplate, fakeFileNamePrefix, fakeFileNameSuffix;
+
+ // e.g. output.it[1-x].someOldRun would be specified as:
+ // output.it?.someOldRun
+ // and we'd have prefix = "output.it" and suffix = ".someOldRun"
+
+ // private int useDisk;
+
+ public AdaGradCore(JoshuaConfiguration joshuaConfiguration) {
+ this.joshuaConfiguration = joshuaConfiguration;
+ }
+
- public AdaGradCore(String[] args, JoshuaConfiguration joshuaConfiguration) {
++ public AdaGradCore(String[] args, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(args);
+ initialize(0);
+ }
+
- public AdaGradCore(String configFileName, JoshuaConfiguration joshuaConfiguration) {
++ public AdaGradCore(String configFileName, JoshuaConfiguration joshuaConfiguration) throws FileNotFoundException, IOException {
+ this.joshuaConfiguration = joshuaConfiguration;
+ EvaluationMetric.set_knownMetrics();
+ processArgsArray(cfgFileToArgsArray(configFileName));
+ initialize(0);
+ }
+
- private void initialize(int randsToSkip) {
++ private void initialize(int randsToSkip) throws FileNotFoundException, IOException {
+ println("NegInf: " + NegInf + ", PosInf: " + PosInf + ", epsilon: " + epsilon, 4);
+
+ randGen = new Random(seed);
+ for (int r = 1; r <= randsToSkip; ++r) {
+ randGen.nextDouble();
+ }
- generatedRands = randsToSkip;
+
+ if (randsToSkip == 0) {
+ println("----------------------------------------------------", 1);
+ println("Initializing...", 1);
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ println("Random number generator initialized using seed: " + seed, 1);
+ println("", 1);
+ }
+
+ // count the total num of sentences to be decoded, reffilename is the combined reference file
+ // name(auto generated)
- numSentences = countLines(refFileName) / refsPerSen;
++ numSentences = new ExistingUTF8EncodedTextFile(refFileName).getNumberOfLines() / refsPerSen;
+
+ // ??
+ processDocInfo();
+ // sets numDocuments and docOfSentence[]
+
+ if (numDocuments > 1)
+ metricName_display = "doc-level " + metricName;
+
+ // ??
+ set_docSubsetInfo(docSubsetInfo);
+
+ // count the number of initial features
- numParams = countNonEmptyLines(paramsFileName) - 1;
++ numParams = new ExistingUTF8EncodedTextFile(paramsFileName).getNumberOfNonEmptyLines() - 1;
+ numParamsOld = numParams;
+
+ // read parameter config file
+ try {
+ // read dense parameter names
+ BufferedReader inFile_names = new BufferedReader(new FileReader(paramsFileName));
+
+ for (int c = 1; c <= numParams; ++c) {
+ String line = "";
+ while (line != null && line.length() == 0) { // skip empty lines
+ line = inFile_names.readLine();
+ }
+
+ // save feature names
+ String paramName = (line.substring(0, line.indexOf("|||"))).trim();
+ Vocabulary.id(paramName);
+ // System.err.println(String.format("VOCAB(%s) = %d", paramName, id));
+ }
+
+ inFile_names.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ // the parameter file contains one line per parameter
+ // and one line for the normalization method
+ // indexing starts at 1 in these arrays
+ for (int p = 0; p <= numParams; ++p)
+ lambda.add(0d);
+ bestLambda.add(0d);
+ // why only lambda is a list? because the size of lambda
+ // may increase over time, but other arrays are specified in
+ // the param config file, only used for initialization
+ isOptimizable = new boolean[1 + numParams];
+ minRandValue = new double[1 + numParams];
+ maxRandValue = new double[1 + numParams];
+ defaultLambda = new double[1 + numParams];
+ normalizationOptions = new double[3];
+
+ // read initial param values
+ processParamFile();
+ // sets the arrays declared just above
+
+ // SentenceInfo.createV(); // uncomment ONLY IF using vocabulary implementation of SentenceInfo
+
+ String[][] refSentences = new String[numSentences][refsPerSen];
+
+ try {
+
+ // read in reference sentences
+ InputStream inStream_refs = new FileInputStream(new File(refFileName));
+ BufferedReader inFile_refs = new BufferedReader(new InputStreamReader(inStream_refs, "utf8"));
+
+ for (int i = 0; i < numSentences; ++i) {
+ for (int r = 0; r < refsPerSen; ++r) {
+ // read the rth reference translation for the ith sentence
+ refSentences[i][r] = inFile_refs.readLine();
+ }
+ }
+
+ inFile_refs.close();
+
+ // normalize reference sentences
+ for (int i = 0; i < numSentences; ++i) {
+ for (int r = 0; r < refsPerSen; ++r) {
+ // normalize the rth reference translation for the ith sentence
+ refSentences[i][r] = normalize(refSentences[i][r], textNormMethod);
+ }
+ }
+
+ // read in decoder command, if any
+ decoderCommand = null;
+ if (decoderCommandFileName != null) {
+ if (fileExists(decoderCommandFileName)) {
+ BufferedReader inFile_comm = new BufferedReader(new FileReader(decoderCommandFileName));
+ decoderCommand = inFile_comm.readLine(); // READ IN DECODE COMMAND
+ inFile_comm.close();
+ }
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ // set static data members for the EvaluationMetric class
+ EvaluationMetric.set_numSentences(numSentences);
+ EvaluationMetric.set_numDocuments(numDocuments);
+ EvaluationMetric.set_refsPerSen(refsPerSen);
+ EvaluationMetric.set_refSentences(refSentences);
+ EvaluationMetric.set_tmpDirPrefix(tmpDirPrefix);
+
+ evalMetric = EvaluationMetric.getMetric(metricName, metricOptions);
+ // used only if returnBest = true
+ prevMetricScore = evalMetric.getToBeMinimized() ? PosInf : NegInf;
+
+ // length of sufficient statistics
+ // for bleu: suffstatscount=8 (2*ngram+2)
+ suffStatsCount = evalMetric.get_suffStatsCount();
+
+ // set static data members for the IntermediateOptimizer class
+ /*
+ * IntermediateOptimizer.set_MERTparams(numSentences, numDocuments, docOfSentence,
+ * docSubsetInfo, numParams, normalizationOptions, isOptimizable oneModificationPerIteration,
+ * evalMetric, tmpDirPrefix, verbosity);
+ */
+
+ // print info
+ if (randsToSkip == 0) { // i.e. first iteration
+ println("Number of sentences: " + numSentences, 1);
+ println("Number of documents: " + numDocuments, 1);
+ println("Optimizing " + metricName_display, 1);
+
+ /*
+ * print("docSubsetInfo: {", 1); for (int f = 0; f < 6; ++f) print(docSubsetInfo[f] + ", ",
+ * 1); println(docSubsetInfo[6] + "}", 1);
+ */
+
+ println("Number of initial features: " + numParams, 1);
+ print("Initial feature names: {", 1);
+
+ for (int c = 1; c <= numParams; ++c)
+ print("\"" + Vocabulary.word(c) + "\"", 1);
+ println("}", 1);
+ println("", 1);
+
+ // TODO just print the correct info
+ println("c Default value\tOptimizable?\tRand. val. range", 1);
+
+ for (int c = 1; c <= numParams; ++c) {
+ print(c + " " + f4.format(lambda.get(c).doubleValue()) + "\t\t", 1);
+
+ if (!isOptimizable[c]) {
+ println(" No", 1);
+ } else {
+ print(" Yes\t\t", 1);
+ print(" [" + minRandValue[c] + "," + maxRandValue[c] + "]", 1);
+ println("", 1);
+ }
+ }
+
+ println("", 1);
+ print("Weight vector normalization method: ", 1);
+ if (normalizationOptions[0] == 0) {
+ println("none.", 1);
+ } else if (normalizationOptions[0] == 1) {
+ println(
+ "weights will be scaled so that the \""
+ + Vocabulary.word((int) normalizationOptions[2])
+ + "\" weight has an absolute value of " + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 2) {
+ println("weights will be scaled so that the maximum absolute value is "
+ + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 3) {
+ println("weights will be scaled so that the minimum absolute value is "
+ + normalizationOptions[1] + ".", 1);
+ } else if (normalizationOptions[0] == 4) {
+ println("weights will be scaled so that the L-" + normalizationOptions[1] + " norm is "
+ + normalizationOptions[2] + ".", 1);
+ }
+
+ println("", 1);
+
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ // rename original config file so it doesn't get overwritten
+ // (original name will be restored in finish())
+ renameFile(decoderConfigFileName, decoderConfigFileName + ".AdaGrad.orig");
+ } // if (randsToSkip == 0)
+
+ // by default, load joshua decoder
+ if (decoderCommand == null && fakeFileNameTemplate == null) {
+ println("Loading Joshua decoder...", 1);
+ myDecoder = new Decoder(joshuaConfiguration);
+ println("...finished loading @ " + (new Date()), 1);
+ println("");
+ } else {
+ myDecoder = null;
+ }
+
+ @SuppressWarnings("unchecked")
+ TreeSet<Integer>[] temp_TSA = new TreeSet[numSentences];
+ indicesOfInterest_all = temp_TSA;
+
+ for (int i = 0; i < numSentences; ++i) {
+ indicesOfInterest_all[i] = new TreeSet<>();
+ }
+ } // void initialize(...)
+
+ // -------------------------
+
+ public void run_AdaGrad() {
+ run_AdaGrad(minMERTIterations, maxMERTIterations, prevMERTIterations);
+ }
+
+ /**
+ * Runs the AdaGrad tuning loop. It removes leftover temp files, optionally randomizes the
+ * initial weights, and calls {@link #run_single_iteration} until that method signals the last
+ * iteration (or returns null). Finally it deletes the per-iteration temp files.
+ *
+ * @param minIts minimum number of iterations before early stopping may trigger
+ * @param maxIts maximum number of iterations
+ * @param prevIts number of previous iterations whose candidates are reused
+ */
+ public void run_AdaGrad(int minIts, int maxIts, int prevIts) {
+ // FIRST, CLEAN ALL PREVIOUS TEMP FILES found in the directory part of tmpDirPrefix
+ String dir;
+ int k = tmpDirPrefix.lastIndexOf("/");
+ if (k >= 0) {
+ dir = tmpDirPrefix.substring(0, k + 1);
+ } else {
+ dir = "./";
+ }
+ // listFiles() returns null if dir does not exist, is not a directory, or cannot be read
+ File[] listOfFiles = new File(dir).listFiles();
+ if (listOfFiles != null) {
+ for (File listOfFile : listOfFiles) {
+ if (listOfFile.isFile() && listOfFile.getName().startsWith("AdaGrad.temp")) {
+ // delete by full path: the bare file name would resolve against the working
+ // directory instead of dir
+ deleteFile(listOfFile.getPath());
+ }
+ }
+ }
+
+ println("----------------------------------------------------", 1);
+ println("AdaGrad run started @ " + (new Date()), 1);
+ // printMemoryUsage();
+ println("----------------------------------------------------", 1);
+ println("", 1);
+
+ // if no default lambda is provided
+ if (randInit) {
+ println("Initializing lambda[] randomly.", 1);
+ // initialize optimizable parameters randomly (sampling uniformly from
+ // that parameter's random value range)
+ lambda = randomLambda();
+ }
+
+ println("Initial lambda[]: " + lambdaToString(lambda), 1);
+ println("", 1);
+
+ int[] maxIndex = new int[numSentences];
+
+ // HashMap<Integer,int[]>[] suffStats_array = new HashMap[numSentences];
+ // suffStats_array[i] maps candidates of interest for sentence i to an array
+ // storing the sufficient statistics for that candidate
+
+ int earlyStop = 0;
+ // number of consecutive iteration an early stopping criterion was satisfied
+
+ for (int iteration = 1;; ++iteration) {
+
+ // what does "A" contain?
+ // retA[0]: FINAL_score
+ // retA[1]: earlyStop
+ // retA[2]: should this be the last iteration?
+ // (null means no new candidates were added and tuning should stop)
+ double[] A = run_single_iteration(iteration, minIts, maxIts, prevIts, earlyStop, maxIndex);
+ if (A != null) {
+ earlyStop = (int) A[1];
+ if (A[2] == 1)
+ break;
+ } else {
+ break;
+ }
+
+ } // for (iteration)
+
+ println("", 1);
+
+ println("----------------------------------------------------", 1);
+ println("AdaGrad run ended @ " + (new Date()), 1);
+ // printMemoryUsage();
+ println("----------------------------------------------------", 1);
+ println("", 1);
+ if (!returnBest)
+ println("FINAL lambda: " + lambdaToString(lambda), 1);
+ // + " (" + metricName_display + ": " + FINAL_score + ")",1);
+ else
+ println("BEST lambda: " + lambdaToString(lambda), 1);
+
+ // delete intermediate .temp.*.it* decoder output files
+ for (int iteration = 1; iteration <= maxIts; ++iteration) {
+ if (compressFiles == 1) {
+ deleteFile(tmpDirPrefix + "temp.sents.it" + iteration + ".gz");
+ deleteFile(tmpDirPrefix + "temp.feats.it" + iteration + ".gz");
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".copy.gz")) {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".copy.gz");
+ } else {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz");
+ }
+ } else {
+ deleteFile(tmpDirPrefix + "temp.sents.it" + iteration);
+ deleteFile(tmpDirPrefix + "temp.feats.it" + iteration);
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".copy")) {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration + ".copy");
+ } else {
+ deleteFile(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+ }
+ }
+ } // void run_AdaGrad(int minIts, int maxIts, int prevIts)
+
+ /**
+ * Runs one AdaGrad iteration (this is the key function). The steps are:
+ * - write the decoder config for the current lambda and decode;
+ * - merge this iteration's candidates with those of up to {@code prevIts} previous iterations,
+ * computing sufficient statistics only for unseen candidates;
+ * - run the {@link Optimizer} to obtain new weights.
+ *
+ * @param iteration current (1-based) iteration number
+ * @param minIts minimum number of iterations before early stopping may trigger
+ * @param maxIts maximum number of iterations
+ * @param prevIts number of previous iterations whose candidates are also considered
+ * @param earlyStop number of consecutive preceding iterations that satisfied an early
+ * stopping criterion
+ * @param maxIndex passed through by the caller; not read in this method
+ * @return {@code null} if no new candidates were added and {@code oneModificationPerIteration}
+ * is false. In that case lambda is left as is, or restored from bestLambda when returnBest
+ * is set. Otherwise returns an array where [0] is FINAL_score (currently always 0), [1] is
+ * the updated early-stop counter, and [2] is 1 if this should be the last iteration and 0
+ * otherwise.
+ */
+ @SuppressWarnings("unchecked")
+ public double[] run_single_iteration(int iteration, int minIts, int maxIts, int prevIts,
+ int earlyStop, int[] maxIndex) {
+ double FINAL_score = 0;
+
+ double[] retA = new double[3];
+ // retA[0]: FINAL_score
+ // retA[1]: earlyStop
+ // retA[2]: should this be the last iteration?
+
+ boolean done = false;
+ retA[2] = 1; // will only be made 0 if we don't break from the following loop
+
+ // save feats and stats for all candidates(old & new)
+ HashMap<String, String>[] feat_hash = new HashMap[numSentences];
+ for (int i = 0; i < numSentences; i++)
+ feat_hash[i] = new HashMap<>();
+
+ HashMap<String, String>[] stats_hash = new HashMap[numSentences];
+ for (int i = 0; i < numSentences; i++)
+ stats_hash[i] = new HashMap<>();
+
+ while (!done) { // NOTE: this "loop" will only be carried out once
+ println("--- Starting AdaGrad iteration #" + iteration + " @ " + (new Date()) + " ---", 1);
+
+ // printMemoryUsage();
+
+ /******************************/
+ // CREATE DECODER CONFIG FILE //
+ /******************************/
+
+ createConfigFile(lambda, decoderConfigFileName, decoderConfigFileName + ".AdaGrad.orig");
+ // i.e. use the original config file as a template
+
+ /***************/
+ // RUN DECODER //
+ /***************/
+
+ if (iteration == 1) {
+ println("Decoding using initial weight vector " + lambdaToString(lambda), 1);
+ } else {
+ println("Redecoding using weight vector " + lambdaToString(lambda), 1);
+ }
+
+ // generate the n-best file after decoding
+ String[] decRunResult = run_decoder(iteration); // iteration passed in case fake decoder will
+ // be used
+ // [0] name of file to be processed
+ // [1] indicates how the output file was obtained:
+ // 1: external decoder
+ // 2: fake decoder
+ // 3: internal decoder
+
+ if (!decRunResult[1].equals("2")) {
+ println("...finished decoding @ " + (new Date()), 1);
+ }
+
+ checkFile(decRunResult[0]);
+
+ /************* END OF DECODING **************/
+
+ println("Producing temp files for iteration " + iteration, 3);
+
+ produceTempFiles(decRunResult[0], iteration);
+
+ // save intermedidate output files
+ // save joshua.config.adagrad.it*
+ if (saveInterFiles == 1 || saveInterFiles == 3) { // make copy of intermediate config file
+ if (!copyFile(decoderConfigFileName, decoderConfigFileName + ".AdaGrad.it" + iteration)) {
+ println("Warning: attempt to make copy of decoder config file (to create"
+ + decoderConfigFileName + ".AdaGrad.it" + iteration + ") was unsuccessful!", 1);
+ }
+ }
+
+ // save output.nest.AdaGrad.it*
+ if (saveInterFiles == 2 || saveInterFiles == 3) { // make copy of intermediate decoder output
+ // file...
+
+ if (!decRunResult[1].equals("2")) { // ...but only if no fake decoder
+ if (!decRunResult[0].endsWith(".gz")) {
+ if (!copyFile(decRunResult[0], decRunResult[0] + ".AdaGrad.it" + iteration)) {
+ println("Warning: attempt to make copy of decoder output file (to create"
+ + decRunResult[0] + ".AdaGrad.it" + iteration + ") was unsuccessful!", 1);
+ }
+ } else {
+ String prefix = decRunResult[0].substring(0, decRunResult[0].length() - 3);
+ if (!copyFile(prefix + ".gz", prefix + ".AdaGrad.it" + iteration + ".gz")) {
+ println("Warning: attempt to make copy of decoder output file (to create" + prefix
+ + ".AdaGrad.it" + iteration + ".gz" + ") was unsuccessful!", 1);
+ }
+ }
+
+ if (compressFiles == 1 && !decRunResult[0].endsWith(".gz")) {
+ gzipFile(decRunResult[0] + ".AdaGrad.it" + iteration);
+ }
+ } // if (!fake)
+ }
+
+ // ------------- end of saving .adagrad.it* files ---------------
+
+ int[] candCount = new int[numSentences];
+ int[] lastUsedIndex = new int[numSentences];
+
+ ConcurrentHashMap[] suffStats_array = new ConcurrentHashMap[numSentences];
+ for (int i = 0; i < numSentences; ++i) {
+ candCount[i] = 0;
+ lastUsedIndex[i] = -1;
+ // suffStats_array[i].clear();
+ suffStats_array[i] = new ConcurrentHashMap<>();
+ }
+
+ // initLambda[0] is not used!
+ double[] initialLambda = new double[1 + numParams];
+ for (int i = 1; i <= numParams; ++i)
+ initialLambda[i] = lambda.get(i);
+
+ // the "score" in initialScore refers to that
+ // assigned by the evaluation metric)
+
+ // you may consider all candidates from iter 1, or from iter (iteration-prevIts) to current
+ // iteration
+ int firstIt = Math.max(1, iteration - prevIts);
+ // i.e. only process candidates from the current iteration and candidates
+ // from up to prevIts previous iterations.
+ println("Reading candidate translations from iterations " + firstIt + "-" + iteration, 1);
+ println("(and computing " + metricName
+ + " sufficient statistics for previously unseen candidates)", 1);
+ print(" Progress: ");
+
+ int[] newCandidatesAdded = new int[1 + iteration];
+ for (int it = 1; it <= iteration; ++it)
+ newCandidatesAdded[it] = 0;
+
+ try {
+ // read temp files from all past iterations
+ // 3 types of temp files:
+ // 1. output hypo at iter i
+ // 2. feature value of each hypo at iter i
+ // 3. suff stats of each hypo at iter i
+
+ // each inFile corresponds to the output of an iteration
+ // (index 0 is not used; no corresponding index for the current iteration)
+ BufferedReader[] inFile_sents = new BufferedReader[iteration];
+ BufferedReader[] inFile_feats = new BufferedReader[iteration];
+ BufferedReader[] inFile_stats = new BufferedReader[iteration];
+
+ // temp file(array) from previous iterations
+ for (int it = firstIt; it < iteration; ++it) {
+ InputStream inStream_sents, inStream_feats, inStream_stats;
+ if (compressFiles == 0) {
+ inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+ inStream_feats = new FileInputStream(tmpDirPrefix + "temp.feats.it" + it);
+ inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+ } else {
+ inStream_sents = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it"
+ + it + ".gz"));
+ inStream_feats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.feats.it"
+ + it + ".gz"));
+ inStream_stats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it"
+ + it + ".gz"));
+ }
+
+ inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+ inFile_feats[it] = new BufferedReader(new InputStreamReader(inStream_feats, "utf8"));
+ inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+ }
+
+ InputStream inStream_sentsCurrIt, inStream_featsCurrIt, inStream_statsCurrIt;
+ // temp file for current iteration!
+ if (compressFiles == 0) {
+ inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+ inStream_featsCurrIt = new FileInputStream(tmpDirPrefix + "temp.feats.it" + iteration);
+ } else {
+ inStream_sentsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+ + "temp.sents.it" + iteration + ".gz"));
+ inStream_featsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+ + "temp.feats.it" + iteration + ".gz"));
+ }
+
+ BufferedReader inFile_sentsCurrIt = new BufferedReader(new InputStreamReader(
+ inStream_sentsCurrIt, "utf8"));
+ BufferedReader inFile_featsCurrIt = new BufferedReader(new InputStreamReader(
+ inStream_featsCurrIt, "utf8"));
+
+ BufferedReader inFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below
+ // is set to true
+ PrintWriter outFile_statsCurrIt = null; // will only be used if statsCurrIt_exists below is
+ // set to false
+
+ // just to check if temp.stat.it.iteration exists
+ boolean statsCurrIt_exists = false;
+
+ if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration)) {
+ inStream_statsCurrIt = new FileInputStream(tmpDirPrefix + "temp.stats.it" + iteration);
+ inFile_statsCurrIt = new BufferedReader(new InputStreamReader(inStream_statsCurrIt,
+ "utf8"));
+ statsCurrIt_exists = true;
+ copyFile(tmpDirPrefix + "temp.stats.it" + iteration, tmpDirPrefix + "temp.stats.it"
+ + iteration + ".copy");
+ } else if (fileExists(tmpDirPrefix + "temp.stats.it" + iteration + ".gz")) {
+ inStream_statsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+ + "temp.stats.it" + iteration + ".gz"));
+ inFile_statsCurrIt = new BufferedReader(new InputStreamReader(inStream_statsCurrIt,
+ "utf8"));
+ statsCurrIt_exists = true;
+ copyFile(tmpDirPrefix + "temp.stats.it" + iteration + ".gz", tmpDirPrefix
+ + "temp.stats.it" + iteration + ".copy.gz");
+ } else {
+ outFile_statsCurrIt = new PrintWriter(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+
+ // output the 4^th temp file: *.temp.stats.merged
+ PrintWriter outFile_statsMerged = new PrintWriter(tmpDirPrefix + "temp.stats.merged");
+ // write sufficient statistics from all the sentences
+ // from the output files into a single file
+ PrintWriter outFile_statsMergedKnown = new PrintWriter(tmpDirPrefix
+ + "temp.stats.mergedKnown");
+ // write sufficient statistics from all the sentences
+ // from the output files into a single file
+
+ // output the 5^th 6^th temp file, but will be deleted at the end of the function
+ FileOutputStream outStream_unknownCands = new FileOutputStream(tmpDirPrefix
+ + "temp.currIt.unknownCands", false);
+ OutputStreamWriter outStreamWriter_unknownCands = new OutputStreamWriter(
+ outStream_unknownCands, "utf8");
+ BufferedWriter outFile_unknownCands = new BufferedWriter(outStreamWriter_unknownCands);
+
+ PrintWriter outFile_unknownIndices = new PrintWriter(tmpDirPrefix
+ + "temp.currIt.unknownIndices");
+
+ String sents_str, feats_str, stats_str;
+
+ // BUG: this assumes a candidate string cannot be produced for two
+ // different source sentences, which is not necessarily true
+ // (It's not actually a bug, but only because existingCandStats gets
+ // cleared before moving to the next source sentence.)
+ // FIX: should be made an array, indexed by i
+ HashMap<String, String> existingCandStats = new HashMap<>();
+ // VERY IMPORTANT:
+ // A CANDIDATE X MAY APPEARED IN ITER 1, ITER 3
+ // BUT IF THE USER SPECIFIED TO CONSIDER ITERATIONS FROM ONLY ITER 2, THEN
+ // X IS NOT A "REPEATED" CANDIDATE IN ITER 3. THEREFORE WE WANT TO KEEP THE
+ // SUFF STATS FOR EACH CANDIDATE(TO SAVE COMPUTATION IN THE FUTURE)
+
+ // Stores precalculated sufficient statistics for candidates, in case
+ // the same candidate is seen again. (SS stored as a String.)
+ // Q: Why do we care? If we see the same candidate again, aren't we going
+ // to ignore it? So, why do we care about the SS of this repeat candidate?
+ // A: A "repeat" candidate may not be a repeat candidate in later
+ // iterations if the user specifies a value for prevMERTIterations
+ // that causes MERT to skip candidates from early iterations.
+
- double[] currFeatVal = new double[1 + numParams];
+ String[] featVal_str;
+
+ int totalCandidateCount = 0;
+
+ // new candidate size for each sentence
+ int[] sizeUnknown_currIt = new int[numSentences];
+
+ for (int i = 0; i < numSentences; ++i) {
+ // process candidates from previous iterations
+ // low efficiency? for each iteration, it reads in all previous iteration outputs
+ // therefore a lot of overlapping jobs
+ // this is an easy implementation to deal with the situation in which user only specified
+ // "previt" and hopes to consider only the previous previt
+ // iterations, then for each iteration the existing candadites will be different
+ for (int it = firstIt; it < iteration; ++it) {
+ // Why up to but *excluding* iteration?
+ // Because the last iteration is handled a little differently, since
+ // the SS must be calculated (and the corresponding file created),
+ // which is not true for previous iterations.
+
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ // note that in all temp files, "||||||" is a separator between 2 n-best lists
+
+ // Why up to and *including* sizeOfNBest?
+ // So that it would read the "||||||" separator even if there is
+ // a complete list of sizeOfNBest candidates.
+
+ // for the nth candidate for the ith sentence, read the sentence, feature values,
+ // and sufficient statistics from the various temp files
+
+ // read one line of temp.sent, temp.feat, temp.stats from iteration it
+ sents_str = inFile_sents[it].readLine();
+ feats_str = inFile_feats[it].readLine();
+ stats_str = inFile_stats[it].readLine();
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1; // move on to the next n-best list
+ } else if (!existingCandStats.containsKey(sents_str)) // if this candidate does not
+ // exist
+ {
+ outFile_statsMergedKnown.println(stats_str);
+
+ // save feats & stats
+ feat_hash[i].put(sents_str, feats_str);
+ stats_hash[i].put(sents_str, stats_str);
+
+ // extract feature value
+ featVal_str = feats_str.split("\\s+");
+
- if (feats_str.indexOf('=') != -1) {
- for (String featurePair : featVal_str) {
- String[] pair = featurePair.split("=");
- String name = pair[0];
- Double value = Double.parseDouble(pair[1]);
- }
- }
+ existingCandStats.put(sents_str, stats_str);
+ candCount[i] += 1;
+ newCandidatesAdded[it] += 1;
+
+ } // if unseen candidate
+ } // for (n)
+ } // for (it)
+
+ outFile_statsMergedKnown.println("||||||");
+
+ // ---------- end of processing previous iterations ----------
+ // ---------- now start processing new candidates ----------
+
+ // now process the candidates of the current iteration
+ // now determine the new candidates of the current iteration
+
+ /*
+ * remember: BufferedReader inFile_sentsCurrIt BufferedReader inFile_featsCurrIt
+ * PrintWriter outFile_statsCurrIt
+ */
+
+ String[] sentsCurrIt_currSrcSent = new String[sizeOfNBest + 1];
+
+ Vector<String> unknownCands_V = new Vector<>();
+ // which candidates (of the i'th source sentence) have not been seen before
+ // this iteration?
+
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ // Why up to and *including* sizeOfNBest?
+ // So that it would read the "||||||" separator even if there is
+ // a complete list of sizeOfNBest candidates.
+
+ // for the nth candidate for the ith sentence, read the sentence,
+ // and store it in the sentsCurrIt_currSrcSent array
+
+ sents_str = inFile_sentsCurrIt.readLine(); // read one candidate from the current
+ // iteration
+ sentsCurrIt_currSrcSent[n] = sents_str; // Note: possibly "||||||"
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+ unknownCands_V.add(sents_str); // NEW CANDIDATE FROM THIS ITERATION
+ writeLine(sents_str, outFile_unknownCands);
+ outFile_unknownIndices.println(i); // INDEX OF THE NEW CANDIDATES
+ newCandidatesAdded[iteration] += 1;
+ existingCandStats.put(sents_str, "U"); // i.e. unknown
+ // we add sents_str to avoid duplicate entries in unknownCands_V
+ }
+ } // for (n)
+
+ // only compute suff stats for new candidates
+ // now unknownCands_V has the candidates for which we need to calculate
+ // sufficient statistics (for the i'th source sentence)
+ int sizeUnknown = unknownCands_V.size();
+ sizeUnknown_currIt[i] = sizeUnknown;
+
+ existingCandStats.clear();
+
+ } // for (i) each sentence
+
+ // ---------- end of merging candidates stats from previous iterations
+ // and finding new candidates ------------
+
+ /*
+ * int[][] newSuffStats = null; if (!statsCurrIt_exists && sizeUnknown > 0) { newSuffStats =
+ * evalMetric.suffStats(unknownCands, indices); }
+ */
+
+ outFile_statsMergedKnown.close();
+ outFile_unknownCands.close();
+ outFile_unknownIndices.close();
+
+ // want to re-open all temp files and start from scratch again?
+ for (int it = firstIt; it < iteration; ++it) // previous iterations temp files
+ {
+ inFile_sents[it].close();
+ inFile_stats[it].close();
+
+ InputStream inStream_sents, inStream_stats;
+ if (compressFiles == 0) {
+ inStream_sents = new FileInputStream(tmpDirPrefix + "temp.sents.it" + it);
+ inStream_stats = new FileInputStream(tmpDirPrefix + "temp.stats.it" + it);
+ } else {
+ inStream_sents = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.sents.it"
+ + it + ".gz"));
+ inStream_stats = new GZIPInputStream(new FileInputStream(tmpDirPrefix + "temp.stats.it"
+ + it + ".gz"));
+ }
+
+ inFile_sents[it] = new BufferedReader(new InputStreamReader(inStream_sents, "utf8"));
+ inFile_stats[it] = new BufferedReader(new InputStreamReader(inStream_stats, "utf8"));
+ }
+
+ inFile_sentsCurrIt.close();
+ // current iteration temp files
+ if (compressFiles == 0) {
+ inStream_sentsCurrIt = new FileInputStream(tmpDirPrefix + "temp.sents.it" + iteration);
+ } else {
+ inStream_sentsCurrIt = new GZIPInputStream(new FileInputStream(tmpDirPrefix
+ + "temp.sents.it" + iteration + ".gz"));
+ }
+ inFile_sentsCurrIt = new BufferedReader(new InputStreamReader(inStream_sentsCurrIt, "utf8"));
+
+ // calculate SS for unseen candidates and write them to file
+ FileInputStream inStream_statsCurrIt_unknown = null;
+ BufferedReader inFile_statsCurrIt_unknown = null;
+
+ if (!statsCurrIt_exists && newCandidatesAdded[iteration] > 0) {
+ // create the file...
+ evalMetric.createSuffStatsFile(tmpDirPrefix + "temp.currIt.unknownCands", tmpDirPrefix
+ + "temp.currIt.unknownIndices", tmpDirPrefix + "temp.stats.unknown", sizeOfNBest);
+
+ // ...and open it
+ inStream_statsCurrIt_unknown = new FileInputStream(tmpDirPrefix + "temp.stats.unknown");
+ inFile_statsCurrIt_unknown = new BufferedReader(new InputStreamReader(
+ inStream_statsCurrIt_unknown, "utf8"));
+ }
+
+ // open mergedKnown file
+ // newly created by the big loop above
+ FileInputStream instream_statsMergedKnown = new FileInputStream(tmpDirPrefix
+ + "temp.stats.mergedKnown");
+ BufferedReader inFile_statsMergedKnown = new BufferedReader(new InputStreamReader(
+ instream_statsMergedKnown, "utf8"));
+
+ // num of features before observing new firing features from this iteration
+ numParamsOld = numParams;
+
+ for (int i = 0; i < numSentences; ++i) {
+ // reprocess candidates from previous iterations
+ for (int it = firstIt; it < iteration; ++it) {
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ sents_str = inFile_sents[it].readLine();
+ stats_str = inFile_stats[it].readLine();
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+ existingCandStats.put(sents_str, stats_str);
+ } // if unseen candidate
+ } // for (n)
+ } // for (it)
+
+ // copy relevant portion from mergedKnown to the merged file
+ String line_mergedKnown = inFile_statsMergedKnown.readLine();
+ while (!line_mergedKnown.equals("||||||")) {
+ outFile_statsMerged.println(line_mergedKnown);
+ line_mergedKnown = inFile_statsMergedKnown.readLine();
+ }
+
+ int[] stats = new int[suffStatsCount];
+
+ for (int n = 0; n <= sizeOfNBest; ++n) {
+ sents_str = inFile_sentsCurrIt.readLine();
+ feats_str = inFile_featsCurrIt.readLine();
+
+ if (sents_str.equals("||||||")) {
+ n = sizeOfNBest + 1;
+ } else if (!existingCandStats.containsKey(sents_str)) {
+
+ if (!statsCurrIt_exists) {
+ stats_str = inFile_statsCurrIt_unknown.readLine();
+
+ String[] temp_stats = stats_str.split("\\s+");
+ for (int s = 0; s < suffStatsCount; ++s) {
+ stats[s] = Integer.parseInt(temp_stats[s]);
+ }
+
+ outFile_statsCurrIt.println(stats_str);
+ } else {
+ stats_str = inFile_statsCurrIt.readLine();
+
+ String[] temp_stats = stats_str.split("\\s+");
+ for (int s = 0; s < suffStatsCount; ++s) {
+ stats[s] = Integer.parseInt(temp_stats[s]);
+ }
+ }
+
+ outFile_statsMerged.println(stats_str);
+
+ // save feats & stats
+ // System.out.println(sents_str+" "+feats_str);
+
+ feat_hash[i].put(sents_str, feats_str);
+ stats_hash[i].put(sents_str, stats_str);
+
+ featVal_str = feats_str.split("\\s+");
+
+ if (feats_str.indexOf('=') != -1) {
+ for (String featurePair : featVal_str) {
+ String[] pair = featurePair.split("=");
+ String name = pair[0];
- Double value = Double.parseDouble(pair[1]);
+ int featId = Vocabulary.id(name);
+
+ // need to identify newly fired feats here
+ // in this case currFeatVal is not given the value
+ // of the new feat, since the corresponding weight is
+ // initialized as zero anyway
+ if (featId > numParams) {
+ ++numParams;
+ lambda.add(0d);
+ }
+ }
+ }
+ existingCandStats.put(sents_str, stats_str);
+ candCount[i] += 1;
+
+ // newCandidatesAdded[iteration] += 1;
+ // moved to code above detecting new candidates
+ } else {
+ if (statsCurrIt_exists)
+ inFile_statsCurrIt.readLine();
+ else {
+ // write SS to outFile_statsCurrIt
+ stats_str = existingCandStats.get(sents_str);
+ outFile_statsCurrIt.println(stats_str);
+ }
+ }
+
+ } // for (n)
+
+ // now d = sizeUnknown_currIt[i] - 1
+
+ if (statsCurrIt_exists)
+ inFile_statsCurrIt.readLine();
+ else
+ outFile_statsCurrIt.println("||||||");
+
+ existingCandStats.clear();
+ totalCandidateCount += candCount[i];
+
+ // output sentence progress
+ if ((i + 1) % 500 == 0) {
+ print((i + 1) + "\n" + " ", 1);
+ } else if ((i + 1) % 100 == 0) {
+ print("+", 1);
+ } else if ((i + 1) % 25 == 0) {
+ print(".", 1);
+ }
+
+ } // for (i)
+
+ inFile_statsMergedKnown.close();
+ outFile_statsMerged.close();
+ // Close the reader over temp.stats.unknown. It is opened only when new candidates needed
+ // stats, and it must be released before the file is deleted below.
+ if (inFile_statsCurrIt_unknown != null) {
+ inFile_statsCurrIt_unknown.close();
+ }
+
+ // for testing
+ /*
+ * int total_sent = 0; for( int i=0; i<numSentences; i++ ) {
+ * System.out.println(feat_hash[i].size()+" "+candCount[i]); total_sent +=
+ * feat_hash[i].size(); feat_hash[i].clear(); }
+ * System.out.println("----------------total sent: "+total_sent); total_sent = 0; for( int
+ * i=0; i<numSentences; i++ ) { System.out.println(stats_hash[i].size()+" "+candCount[i]);
+ * total_sent += stats_hash[i].size(); stats_hash[i].clear(); }
+ * System.out.println("*****************total sent: "+total_sent);
+ */
+
+ println("", 1); // finish progress line
+
+ for (int it = firstIt; it < iteration; ++it) {
+ inFile_sents[it].close();
+ inFile_feats[it].close();
+ inFile_stats[it].close();
+ }
+
+ inFile_sentsCurrIt.close();
+ inFile_featsCurrIt.close();
+ if (statsCurrIt_exists)
+ inFile_statsCurrIt.close();
+ else
+ outFile_statsCurrIt.close();
+
+ if (compressFiles == 1 && !statsCurrIt_exists) {
+ gzipFile(tmpDirPrefix + "temp.stats.it" + iteration);
+ }
+
+ // clear temp files
+ deleteFile(tmpDirPrefix + "temp.currIt.unknownCands");
+ deleteFile(tmpDirPrefix + "temp.currIt.unknownIndices");
+ deleteFile(tmpDirPrefix + "temp.stats.unknown");
+ deleteFile(tmpDirPrefix + "temp.stats.mergedKnown");
+
+ // cleanupMemory();
+
+ println("Processed " + totalCandidateCount + " distinct candidates " + "(about "
+ + totalCandidateCount / numSentences + " per sentence):", 1);
+ for (int it = firstIt; it <= iteration; ++it) {
+ println("newCandidatesAdded[it=" + it + "] = " + newCandidatesAdded[it] + " (about "
+ + newCandidatesAdded[it] / numSentences + " per sentence)", 1);
+ }
+
+ println("", 1);
+
+ println("Number of features observed so far: " + numParams);
+ println("", 1);
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ // n-best list converges
+ if (newCandidatesAdded[iteration] == 0) {
+ if (!oneModificationPerIteration) {
+ println("No new candidates added in this iteration; exiting AdaGrad.", 1);
+ println("", 1);
+ println("--- AdaGrad iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+ println("", 1);
+ deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+ if (returnBest) {
+ // note that bestLambda.size() <= lambda.size()
+ for (int p = 1; p < bestLambda.size(); ++p)
+ lambda.set(p, bestLambda.get(p));
+ // and set the rest of lambda to be 0
+ for (int p = 0; p < lambda.size() - bestLambda.size(); ++p)
+ lambda.set(p + bestLambda.size(), 0d);
+ }
+
+ return null; // this means that the old values should be kept by the caller
+ } else {
+ println("Note: No new candidates added in this iteration.", 1);
+ }
+ }
+
+ /************* start optimization **************/
+
+ /*
+ * for( int v=1; v<initialLambda[1].length; v++ ) System.out.print(initialLambda[1][v]+" ");
+ * System.exit(0);
+ */
+
+ Optimizer.sentNum = numSentences; // total number of training sentences
+ Optimizer.needShuffle = needShuffle;
+ Optimizer.adagradIter = adagradIter;
+ Optimizer.oraSelectMode = oraSelectMode;
+ Optimizer.predSelectMode = predSelectMode;
+ Optimizer.needAvg = needAvg;
+ // Optimizer.sentForScale = sentForScale;
+ Optimizer.scoreRatio = scoreRatio;
+ Optimizer.evalMetric = evalMetric;
+ Optimizer.normalizationOptions = normalizationOptions;
+ Optimizer.needScale = needScale;
+ Optimizer.regularization = regularization;
+ Optimizer.batchSize = batchSize;
+ Optimizer.eta = eta;
+ Optimizer.lam = lam;
+
+ // if need to use bleu stats history
+ if (iteration == 1) {
+ if (evalMetric.get_metricName().equals("BLEU") && usePseudoBleu) {
+ Optimizer.initBleuHistory(numSentences, evalMetric.get_suffStatsCount());
+ Optimizer.usePseudoBleu = usePseudoBleu;
+ Optimizer.R = R;
+ }
+ if (evalMetric.get_metricName().equals("TER-BLEU") && usePseudoBleu) {
+ Optimizer.initBleuHistory(numSentences, evalMetric.get_suffStatsCount() - 2); // Stats
+ // count of
+ // TER=2
+ Optimizer.usePseudoBleu = usePseudoBleu;
+ Optimizer.R = R;
+ }
+ }
+
+ Vector<String> output = new Vector<>();
+
+ // note: initialLambda[] has length = numParamsOld
+ // augmented with new feature weights, initial values are 0
+ double[] initialLambdaNew = new double[1 + numParams];
+ System.arraycopy(initialLambda, 1, initialLambdaNew, 1, numParamsOld);
+
+ // finalLambda[] has length = numParams (considering new features)
+ double[] finalLambda = new double[1 + numParams];
+
+ Optimizer opt = new Optimizer(output, isOptimizable, initialLambdaNew, feat_hash, stats_hash);
+ finalLambda = opt.runOptimizer();
+
+ if (returnBest) {
+ double metricScore = opt.getMetricScore();
+ if (!evalMetric.getToBeMinimized()) {
+ if (metricScore > prevMetricScore) {
+ prevMetricScore = metricScore;
+ for (int p = 1; p < bestLambda.size(); ++p)
+ bestLambda.set(p, finalLambda[p]);
+ if (1 + numParams > bestLambda.size()) {
+ for (int p = bestLambda.size(); p <= numParams; ++p)
+ bestLambda.add(p, finalLambda[p]);
+ }
+ }
+ } else {
+ if (metricScore < prevMetricScore) {
+ prevMetricScore = metricScore;
+ for (int p = 1; p < bestLambda.size(); ++p)
+ bestLambda.set(p, finalLambda[p]);
+ if (1 + numParams > bestLambda.size()) {
+ for (int p = bestLambda.size(); p <= numParams; ++p)
+ bestLambda.add(p, finalLambda[p]);
+ }
+ }
+ }
+ }
+
+ // System.out.println(finalLambda.length);
+ // for( int i=0; i<finalLambda.length-1; i++ )
+ // System.out.println(finalLambda[i+1]);
+
+ /************* end optimization **************/
+
+ for (String anOutput : output)
+ println(anOutput);
+
+ // check if any parameter has been updated
+ boolean anyParamChanged = false;
+ boolean anyParamChangedSignificantly = false;
+
+ for (int c = 1; c <= numParams; ++c) {
+ if (finalLambda[c] != lambda.get(c)) {
+ anyParamChanged = true;
+ }
+ if (Math.abs(finalLambda[c] - lambda.get(c)) > stopSigValue) {
+ anyParamChangedSignificantly = true;
+ }
+ }
+
+ // System.arraycopy(finalLambda,1,lambda,1,numParams);
+
+ println("--- AdaGrad iteration #" + iteration + " ending @ " + (new Date()) + " ---", 1);
+ println("", 1);
+
+ if (!anyParamChanged) {
+ println("No parameter value changed in this iteration; exiting AdaGrad.", 1);
+ println("", 1);
+ break; // leave while (!done); retA[2] stays 1, so the caller stops iterating
+ }
+
+ // was an early stopping criterion satisfied?
+ boolean critSatisfied = false;
+ if (!anyParamChangedSignificantly && stopSigValue >= 0) {
+ println("Note: No parameter value changed significantly " + "(i.e. by more than "
+ + stopSigValue + ") in this iteration.", 1);
+ critSatisfied = true;
+ }
+
+ if (critSatisfied) {
+ ++earlyStop;
+ println("", 1);
+ } else {
+ earlyStop = 0;
+ }
+
+ // if min number of iterations executed, investigate if early exit should happen
+ if (iteration >= minIts && earlyStop >= stopMinIts) {
+ println("Some early stopping criteria has been observed " + "in " + stopMinIts
+ + " consecutive iterations; exiting AdaGrad.", 1);
+ println("", 1);
+
+ if (returnBest) {
+ for (int f = 1; f <= bestLambda.size() - 1; ++f)
+ lambda.set(f, bestLambda.get(f));
+ } else {
+ for (int f = 1; f <= numParams; ++f)
+ lambda.set(f, finalLambda[f]);
+ }
+
+ break; // leave while (!done); retA[2] stays 1, so the caller stops iterating
+ }
+
+ // if max number of iterations executed, exit
+ if (iteration >= maxIts) {
+ println("Maximum number of AdaGrad iterations reached; exiting AdaGrad.", 1);
+ println("", 1);
+
+ if (returnBest) {
+ for (int f = 1; f <= bestLambda.size() - 1; ++f)
+ lambda.set(f, bestLambda.get(f));
+ } else {
+ for (int f = 1; f <= numParams; ++f)
+ lambda.set(f, finalLambda[f]);
+ }
+
+ break; // leave while (!done); retA[2] stays 1, so the caller stops iterating
+ }
+
+ // use the new wt vector to decode the next iteration
+ // (interpolation with previous wt vector)
+ double interCoef = 1.0; // no interpolation for now
+ for (int i = 1; i <= numParams; i++)
+ lambda.set(i, interCoef * finalLambda[i] + (1 - interCoef) * lambda.get(i));
+
+ println("Next iteration will decode with lambda: " + lambdaToString(lambda), 1);
+ println("", 1);
+
+ // printMemoryUsage();
+ for (int i = 0; i < numSentences; ++i) {
+ suffStats_array[i].clear();
+ }
+ // cleanupMemory();
+ // println("",2);
+
+ retA[2] = 0; // i.e. this should NOT be the last iteration
+ done = true;
+
+ } // while (!done) // NOTE: this "loop" will only be carried out once
+
+ // delete .temp.stats.merged file, since it is not needed in the next
+ // iteration (it will be recreated from scratch)
+ deleteFile(tmpDirPrefix + "temp.stats.merged");
+
+ retA[0] = FINAL_score;
+ retA[1] = earlyStop;
+ return retA;
+
+ } // run_single_iteration
+
+  /**
+   * Renders at most the first 15 weights of {@code lambdaA} (indices 1..15; index 0 is not
+   * printed) with four decimals, for progress logging.
+   *
+   * @param lambdaA weight vector, read at indices 1..min(15, numParams)
+   * @return e.g. "{(listing the first 3 lambdas)0.1000, 0.2000, 0.3000}"
+   */
+  private String lambdaToString(ArrayList<Double> lambdaA) {
+    // print at most the first 15 features
+    int featToPrint = numParams > 15 ? 15 : numParams;
+
+    StringBuilder retStr = new StringBuilder("{");
+    retStr.append("(listing the first ").append(featToPrint).append(" lambdas)");
+    // print exactly lambda[1..featToPrint] so the output matches the "first N" label
+    for (int c = 1; c <= featToPrint; ++c) {
+      if (c > 1) {
+        retStr.append(", ");
+      }
+      retStr.append(String.format("%.4f", lambdaA.get(c)));
+    }
+    retStr.append("}");
+
+    return retStr.toString();
+  }
+
+  /**
+   * Obtains the n-best output for the given iteration. If a fake decoder output file
+   * ({@code fakeFileNamePrefix + iteration + fakeFileNameSuffix}) exists, it is used directly.
+   * Otherwise the external decoder command is run and must exit with
+   * {@code validDecoderExitValue}.
+   *
+   * @param iteration current iteration (appended to the decoder command when
+   *        passIterationToDecoder is set)
+   * @return {file to be processed, how it was obtained}: "1" = external decoder,
+   *         "2" = fake decoder ("3", internal decoder, is not produced here)
+   * @throws RuntimeException if the decoder cannot be started, is interrupted, or exits with
+   *         an unexpected status
+   */
+  private String[] run_decoder(int iteration) {
+    String[] retSA = new String[2];
+
+    // use fake decoder
+    if (fakeFileNameTemplate != null
+        && fileExists(fakeFileNamePrefix + iteration + fakeFileNameSuffix)) {
+      String fakeFileName = fakeFileNamePrefix + iteration + fakeFileNameSuffix;
+      println("Not running decoder; using " + fakeFileName + " instead.", 1);
+      retSA[0] = fakeFileName;
+      retSA[1] = "2";
+      return retSA;
+    }
+
+    println("Running external decoder...", 1);
+
+    try {
+      ArrayList<String> cmd = new ArrayList<>();
+      cmd.add(decoderCommandFileName);
+      if (passIterationToDecoder) {
+        cmd.add(Integer.toString(iteration));
+      }
+
+      ProcessBuilder pb = new ProcessBuilder(cmd);
+      // this merges the error and output streams of the subprocess
+      pb.redirectErrorStream(true);
+      Process p = pb.start();
+
+      // capture the sub-command's output
+      new StreamGobbler(p.getInputStream(), decVerbosity).start();
+
+      int decStatus = p.waitFor();
+      if (decStatus != validDecoderExitValue) {
+        throw new RuntimeException("Call to decoder returned " + decStatus + "; was expecting "
+            + validDecoderExitValue + ".");
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } catch (InterruptedException e) {
+      // restore the interrupt flag so code further up the stack can still observe it
+      Thread.currentThread().interrupt();
+      throw new RuntimeException(e);
+    }
+
+    retSA[0] = decoderOutFileName;
+    retSA[1] = "1";
+    return retSA;
+  }
+
+  /**
+   * Splits the decoder's n-best file into two per-iteration temp files. One holds the
+   * (normalized) candidate sentences and the other their feature strings. In both files the
+   * candidates of consecutive sentences are separated by a "||||||" line. Both files are gzipped
+   * afterwards when {@code compressFiles == 1}.
+   *
+   * @param nbestFileName n-best file to read (read as gzip if the name ends in ".gz")
+   * @param iteration current iteration, used to name the temp files
+   * @throws RuntimeException wrapping any {@link IOException}
+   */
+  private void produceTempFiles(String nbestFileName, int iteration) {
+    String sentsFileName = tmpDirPrefix + "temp.sents.it" + iteration;
+    String featsFileName = tmpDirPrefix + "temp.feats.it" + iteration;
+
+    try {
+      // try-with-resources closes every stream even if a malformed n-best line throws;
+      // the output files must also be closed before they are gzipped below
+      try (BufferedWriter outFile_sents = new BufferedWriter(new OutputStreamWriter(
+              new FileOutputStream(sentsFileName, false), "utf8"));
+          PrintWriter outFile_feats = new PrintWriter(featsFileName);
+          InputStream rawNbest = new FileInputStream(nbestFileName);
+          InputStream inStream_nbest = nbestFileName.endsWith(".gz")
+              ? new GZIPInputStream(rawNbest) : rawNbest;
+          BufferedReader inFile_nbest = new BufferedReader(
+              new InputStreamReader(inStream_nbest, "utf8"))) {
+
+        int i = 0; // index of the sentence currently being read
+        int n = 0; // candidates read so far for sentence i
+        String line = inFile_nbest.readLine();
+
+        while (line != null) {
+          /*
+           * line format:
+           *
+           * i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val
+           * .*
+           */
+          int read_i = Integer.parseInt((line.substring(0, line.indexOf("|||"))).trim());
+
+          // a new sentence index before sizeOfNBest candidates were seen: close the previous one
+          if (read_i != i) {
+            writeLine("||||||", outFile_sents);
+            outFile_feats.println("||||||");
+            n = 0;
+            ++i;
+          }
+
+          line = (line.substring(line.indexOf("|||") + 3)).trim(); // get rid of initial text
+
+          String candidate_str = (line.substring(0, line.indexOf("|||"))).trim();
+          String feats_str = (line.substring(line.indexOf("|||") + 3)).trim();
+
+          // drop any further "|||"-separated field after the features
+          int junk_i = feats_str.indexOf("|||");
+          if (junk_i >= 0) {
+            feats_str = (feats_str.substring(0, junk_i)).trim();
+          }
+
+          writeLine(normalize(candidate_str, textNormMethod), outFile_sents);
+          outFile_feats.println(feats_str);
+
+          ++n;
+          if (n == sizeOfNBest) {
+            writeLine("||||||", outFile_sents);
+            outFile_feats.println("||||||");
+            n = 0;
+            ++i;
+          }
+
+          line = inFile_nbest.readLine();
+        }
+
+        if (i != numSentences) { // last sentence had too few candidates
+          writeLine("||||||", outFile_sents);
+          outFile_feats.println("||||||");
+        }
+      }
+
+      if (compressFiles == 1) {
+        gzipFile(sentsFileName);
+        gzipFile(featsFileName);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Writes {@code cfgFileName} as a copy of {@code templateFileName}. Every template line that
+   * starts with a feature name followed by a space is replaced by "name value" using the weight
+   * from {@code params}, or dropped when that weight is (almost) zero. Features whose index
+   * exceeds the number of matched template lines are appended at the end (again skipping
+   * near-zero weights).
+   *
+   * @throws RuntimeException wrapping any {@link IOException}, or if writing the output failed
+   */
+  private void createConfigFile(ArrayList<Double> params, String cfgFileName,
+      String templateFileName) {
+    // try-with-resources closes both files on every exit path
+    try (BufferedReader inFile = new BufferedReader(new FileReader(templateFileName));
+        PrintWriter outFile = new PrintWriter(cfgFileName)) {
+
-    BufferedReader inFeatDefFile = null;
-    PrintWriter outFeatDefFile = null;
+      int origFeatNum = 0; // feat num in the template file
+
+      String line = inFile.readLine();
+      while (line != null) {
+        int c_match = -1;
+        for (int c = 1; c <= numParams; ++c) {
+          if (line.startsWith(Vocabulary.word(c) + " ")) {
+            c_match = c;
+            ++origFeatNum;
+            break;
+          }
+        }
+
+        if (c_match == -1) {
+          outFile.println(line);
+        } else if (Math.abs(params.get(c_match)) > 1e-20) {
+          outFile.println(Vocabulary.word(c_match) + " " + params.get(c_match));
+        }
+
+        line = inFile.readLine();
+      }
+
+      // now append weights of new features
+      // NOTE(review): assumes the template holds exactly features 1..origFeatNum — confirm
+      for (int c = origFeatNum + 1; c <= numParams; ++c) {
+        if (Math.abs(params.get(c)) > 1e-20)
+          outFile.println(Vocabulary.word(c) + " " + params.get(c));
+      }
+
+      // PrintWriter never throws on write failure; surface it instead of leaving a truncated config
+      if (outFile.checkError()) {
+        throw new RuntimeException("Error writing decoder config file " + cfgFileName);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Reads the parameter file {@code paramsFileName}. For each of the {@code numParams}
+   * parameters it reads, after the "|||" separator:
+   * <ul>
+   *   <li>the default value, stored in lambda[c] and defaultLambda[c];</li>
+   *   <li>"Opt" or "Fix";</li>
+   *   <li>two values kept only for ZMERT format compatibility;</li>
+   *   <li>the min/max of the random-initialization range, stored for optimizable parameters and
+   *       skipped for fixed ones.</li>
+   * </ul>
+   * It then parses the "normalization = ..." line into normalizationOptions[].
+   *
+   * @throws RuntimeException if the file is missing or contains an invalid entry
+   */
+  private void processParamFile() {
+    // process parameter file; try-with-resources closes the Scanner on every exit path
+    try (Scanner inFile_init = new Scanner(new FileReader(paramsFileName))) {
+      String dummy = "";
+
+      // initialize lambda[] and other related arrays
+      for (int c = 1; c <= numParams; ++c) {
+        // skip parameter name (for c > 1, dummy still holds the previous entry's last token)
+        while (!dummy.equals("|||")) {
+          dummy = inFile_init.next();
+        }
+
+        // read default value
+        lambda.set(c, inFile_init.nextDouble());
+        defaultLambda[c] = lambda.get(c);
+
+        // read isOptimizable
+        dummy = inFile_init.next();
+        if (dummy.equals("Opt")) {
+          isOptimizable[c] = true;
+        } else if (dummy.equals("Fix")) {
+          isOptimizable[c] = false;
+        } else {
+          throw new RuntimeException("Unknown isOptimizable string " + dummy + " (must be either Opt or Fix)");
+        }
+
+        if (!isOptimizable[c]) { // skip the next four values
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+        } else {
+          // the next two values are not used, only to be consistent with ZMERT's params file format
+          dummy = inFile_init.next();
+          dummy = inFile_init.next();
+
+          // set minRandValue[c] and maxRandValue[c] (range for random values)
+          dummy = inFile_init.next();
+          if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+            throw new RuntimeException("minRandValue[" + c + "] cannot be -Inf or +Inf!");
+          }
+          minRandValue[c] = Double.parseDouble(dummy);
+
+          dummy = inFile_init.next();
+          if (dummy.equals("-Inf") || dummy.equals("+Inf")) {
+            throw new RuntimeException("maxRandValue[" + c + "] cannot be -Inf or +Inf!");
+          }
+          maxRandValue[c] = Double.parseDouble(dummy);
+
+          // check for illogical values
+          if (minRandValue[c] > maxRandValue[c]) {
+            throw new RuntimeException("minRandValue[" + c + "]=" + minRandValue[c] + " > " + maxRandValue[c]
+                + "=maxRandValue[" + c + "]!");
+          }
+
+          // check for odd values
+          if (minRandValue[c] == maxRandValue[c]) {
+            println("Warning: lambda[" + c + "] has " + "minRandValue = maxRandValue = "
+                + minRandValue[c] + ".", 1);
+          }
+        }
+      }
+
+      // set normalizationOptions[]: skip the rest of the last parameter line and any blank
+      // (including whitespace-only) lines, so trailing spaces don't break parsing below
+      String origLine = "";
+      while (origLine != null && origLine.trim().length() == 0) {
+        origLine = inFile_init.nextLine();
+      }
+
+      // How should a lambda[] vector be normalized (before decoding)?
+      // nO[0] = 0: no normalization
+      // nO[0] = 1: scale so that parameter nO[2] has absolute value nO[1]
+      // nO[0] = 2: scale so that the maximum absolute value is nO[1]
+      // nO[0] = 3: scale so that the minimum absolute value is nO[1]
+      // nO[0] = 4: scale so that the L-nO[1] norm equals nO[2]
+
+      // normalization = none
+      // normalization = absval 1 lm
+      // normalization = maxabsval 1
+      // normalization = minabsval 1
+      // normalization = LNorm 2 1
+
+      dummy = (origLine.substring(origLine.indexOf("=") + 1)).trim();
+      String[] dummyA = dummy.split("\\s+");
+
+      if (dummyA[0].equals("none")) {
+        normalizationOptions[0] = 0;
+      } else if (dummyA[0].equals("absval")) {
+        normalizationOptions[0] = 1;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        String pName = dummyA[2];
+        for (int i = 3; i < dummyA.length; ++i) { // in case parameter name has multiple words
+          pName = pName + " " + dummyA[i];
+        }
+        normalizationOptions[2] = Vocabulary.id(pName);
+
+        if (normalizationOptions[1] <= 0) {
+          throw new RuntimeException("Value for the absval normalization method must be positive.");
+        }
+        if (normalizationOptions[2] == 0) {
+          // report the name the user gave; the looked-up id is 0 here and uninformative
+          throw new RuntimeException("Unrecognized feature name " + pName
+              + " for absval normalization method.");
+        }
+      } else if (dummyA[0].equals("maxabsval")) {
+        normalizationOptions[0] = 2;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        if (normalizationOptions[1] <= 0) {
+          throw new RuntimeException("Value for the maxabsval normalization method must be positive.");
+        }
+      } else if (dummyA[0].equals("minabsval")) {
+        normalizationOptions[0] = 3;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        if (normalizationOptions[1] <= 0) {
+          throw new RuntimeException("Value for the minabsval normalization method must be positive.");
+        }
+      } else if (dummyA[0].equals("LNorm")) {
+        normalizationOptions[0] = 4;
+        normalizationOptions[1] = Double.parseDouble(dummyA[1]);
+        normalizationOptions[2] = Double.parseDouble(dummyA[2]);
+        if (normalizationOptions[1] <= 0 || normalizationOptions[2] <= 0) {
+          throw new RuntimeException("Both values for the LNorm normalization method must be positive.");
+        }
+      } else {
+        throw new RuntimeException("Unrecognized normalization method " + dummyA[0] + "; "
+            + "must be one of none, absval, maxabsval, minabsval, and LNorm.");
+      } // if (dummyA[0])
+    } catch (FileNotFoundException e) {
+      throw new RuntimeException(e);
+    }
+  } // processParamFile()
+
+  /**
+   * Sets {@code numDocuments} and {@code docOfSentence[]}, the 0-based document index of each
+   * sentence, from {@code docInfoFileName}. Without a doc-info file all sentences belong to a
+   * single document.
+   *
+   * @throws RuntimeException on I/O errors, or when the doc-info file is inconsistent with
+   *         {@code numSentences}
+   */
+  private void processDocInfo() {
+    // sets numDocuments and docOfSentence[]
+    docOfSentence = new int[numSentences]; // zero-filled: every sentence starts in document 0
+
+    if (docInfoFileName == null) {
+      numDocuments = 1;
+    } else {
+
+      try {
+
+      // 4 possible formats:
+      // 1) List of numbers, one per document, indicating # sentences in each document.
+      // 2) List of "docName size" pairs, one per document, indicating name of document and #
+      // sentences.
+      // 3) List of docName's, one per sentence, indicating which document each sentence belongs
+      // to.
+      // 4) List of docName_number's, one per sentence, indicating which document each sentence
+      // belongs to,
+      // and its order in that document. (can also use '-' instead of '_')
+
-       int docInfoSize = countNonEmptyLines(docInfoFileName);
++      int docInfoSize = new ExistingUTF8EncodedTextFile(docInfoFileName).getNumberOfNonEmptyLines();
+
+      if (docInfoSize < numSentences) { // format #1 or #2
+        numDocuments = docInfoSize;
+        int i = 0;
+
+        try (BufferedReader inFile = new BufferedReader(new FileReader(docInfoFileName))) {
+          String line = inFile.readLine();
+          boolean format1 = (!(line.contains(" ")));
+
+          for (int doc = 0; doc < numDocuments; ++doc) {
+            if (doc != 0)
+              line = inFile.readLine();
+
+            int docSize;
+            if (format1) {
+              docSize = Integer.parseInt(line);
+            } else {
+              docSize = Integer.parseInt(line.split("\\s+")[1]);
+            }
+
+            // fail with a clear message instead of running off the end of docOfSentence[]
+            if (i + docSize > numSentences) {
+              throw new RuntimeException("Document sizes in " + docInfoFileName
+                  + " add up to more than the " + numSentences + " sentences.");
+            }
+            for (int i2 = 1; i2 <= docSize; ++i2) {
+              docOfSentence[i] = doc;
+              ++i;
+            }
+          }
+        }
+
+        // now i == numSentences (when the document sizes are consistent)
+
+      } else if (docInfoSize == numSentences) { // format #3 or #4
+
+        // format #3 if any line occurs more than once, format #4 otherwise
+        boolean format3 = false;
+        HashSet<String> seenStrings = new HashSet<>();
+        try (BufferedReader inFile = new BufferedReader(new FileReader(docInfoFileName))) {
+          for (int i = 0; i < numSentences; ++i) {
+            if (!seenStrings.add(inFile.readLine())) {
+              format3 = true;
+            }
+          }
+        }
+
+        // maps a document name to the order (0-indexed) in which it was first seen
+        HashMap<String, Integer> docOrder = new HashMap<>();
+        try (BufferedReader inFile = new BufferedReader(new FileReader(docInfoFileName))) {
+          for (int i = 0; i < numSentences; ++i) {
+            String line = inFile.readLine();
+
+            String docName = line;
+            if (!format3) {
+              // format #4: strip the trailing "_number" / "-number"; a line without any
+              // separator is treated as a whole document name rather than crashing
+              int sep_i = Math.max(line.lastIndexOf('_'), line.lastIndexOf('-'));
+              if (sep_i >= 0) {
+                docName = line.substring(0, sep_i);
+              }
+            }
+
+            Integer docOrder_i = docOrder.get(docName);
+            if (docOrder_i == null) {
+              docOrder_i = docOrder.size();
+              docOrder.put(docName, docOrder_i);
+            }
+            docOfSentence[i] = docOrder_i;
+          }
+        }
+
+        numDocuments = docOrder.size();
+
+      } else { // badly formatted: more non-empty lines than sentences
+        throw new RuntimeException("Document info file " + docInfoFileName + " has " + docInfoSize
+            + " non-empty lines, but there are only " + numSentences + " sentences.");
+      }
+
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+
+  /**
+   * Copies {@code origFileName} byte-for-byte to {@code newFileName}, truncating any existing
+   * content of the target.
+   *
+   * @return true on success; false if an I/O error occurred (the error is logged)
+   */
+  private boolean copyFile(String origFileName, String newFileName) {
+    // try-with-resources closes both streams even when the copy fails part-way;
+    // a failure while closing is also reported as false, as before
+    try (InputStream in = new FileInputStream(new File(origFileName));
+        OutputStream out = new FileOutputStream(new File(newFileName))) {
+      byte[] buffer = new byte[8192];
+      int len;
+      while ((len = in.read(buffer)) > 0) {
+        out.write(buffer, 0, len);
+      }
+      return true;
+    } catch (IOException e) {
+      LOG.error(e.getMessage(), e);
+      return false;
+    }
+  }
+
+  /**
+   * Renames {@code origFileName} to {@code newFileName}, deleting any existing target first.
+   * Failures (missing source, unsuccessful rename) are reported as warnings, never thrown.
+   */
+  private void renameFile(String origFileName, String newFileName) {
+    if (!fileExists(origFileName)) {
+      println("Warning: file " + origFileName + " does not exist! (in AdaGradCore.renameFile)", 1);
+      return;
+    }
+
+    deleteFile(newFileName);
+    boolean renamed = new File(origFileName).renameTo(new File(newFileName));
+    if (!renamed) {
+      println("Warning: attempt to rename " + origFileName + " to " + newFileName
+          + " was unsuccessful!", 1);
+    }
+  }
+
+  /** Deletes {@code fileName} if it exists; a failed deletion only produces a warning. */
+  private void deleteFile(String fileName) {
+    if (!fileExists(fileName)) {
+      return; // nothing to delete
+    }
+
+    boolean deleted = new File(fileName).delete();
+    if (!deleted) {
+      println("Warning: attempt to delete " + fileName + " was unsuccessful!", 1);
+    }
+  }
+
+  /**
+   * Writes {@code line} followed by the platform line separator, then flushes {@code writer}
+   * so the line reaches the underlying stream immediately.
+   */
+  private void writeLine(String line, BufferedWriter writer) throws IOException {
+    int lineLength = line.length();
+    writer.write(line, 0, lineLength);
+    writer.newLine();
+    writer.flush();
+  }
+
+ // need to re-write to handle different forms of lambda
+ public void finish() {
+ if (myDecoder != null) {
+ myDecoder.cleanUp();
+ }
+
+ // create config file with final values
+ createConfigFile(lambda, decoderConfigFileName + ".AdaGrad.final", decoderConfigFileName
+ + ".AdaGrad.orig");
+
+ // delete current decoder config file and decoder output
+ deleteFile(decoderConfigFileName);
+ deleteFile(decoderOutFileName);
+
+ // restore original name for config file (name was changed
+ // in initialize() so it doesn't get overwritten)
+ renameFile(decoderConfigFileName + ".AdaGrad.orig", decoderConfigFileName);
+
+ if (finalLambdaFileName != null) {
+ try {
+ PrintWriter outFile_lambdas = new PrintWriter(finalLambdaFileName);
+ for (int c = 1; c <= numParams; ++c) {
+ outFile_lambdas.println(Vocabulary.word(c) + " ||| " + lambda.get(c));
+ }
+ outFile_lambdas.close();
+
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+  /**
+   * Reads an AdaGrad configuration file and flattens it into a command-line style argument
+   * array.
+   * <ul>
+   *   <li>Text from the first '#' on is a comment and is ignored.</li>
+   *   <li>Blank and comment-only lines are skipped.</li>
+   *   <li>Every other line must be "-option value"; single quotes group words into one
+   *       argument.</li>
+   *   <li>"-m" and "-docSet" lines may carry additional values.</li>
+   * </ul>
+   *
+   * @param fileName configuration file to read
+   * @return the options and values in file order
+   * @throws RuntimeException if the file is missing, unreadable, or contains a malformed line
+   */
+  private String[] cfgFileToArgsArray(String fileName) {
+    checkFile(fileName);
+
+    ArrayList<String> argsList = new ArrayList<>();
+
+    // try-with-resources also closes the reader when a malformed line aborts parsing
+    try (BufferedReader inFile = new BufferedReader(new FileReader(fileName))) {
+      String origLine; // kept unmodified for error reporting
+      while ((origLine = inFile.readLine()) != null) {
+        String line = origLine;
+        if (line.contains("#")) { // discard comment
+          line = line.substring(0, line.indexOf("#"));
+        }
+        line = line.trim();
+
+        // skip blank lines and comment-only lines, including indented ones
+        if (line.length() == 0) {
+          continue;
+        }
+
+        // now line should look like "-xxx XXX"
+
+        // cmu modification (from meteor for zmert): split on whitespace, but treat text
+        // between single quotes as one argument
+        ArrayList<String> argList = new ArrayList<>();
+        StringBuilder arg = new StringBuilder();
+        boolean quoted = false;
+        for (int i = 0; i < line.length(); i++) {
+          char ch = line.charAt(i);
+          if (Character.isWhitespace(ch)) {
+            if (quoted) {
+              arg.append(ch);
+            } else if (arg.length() > 0) {
+              argList.add(arg.toString());
+              arg = new StringBuilder();
+            }
+          } else if (ch == '\'') {
+            if (quoted) {
+              argList.add(arg.toString());
+              arg = new StringBuilder();
+            }
+            quoted = !quoted;
+          } else {
+            arg.append(ch);
+          }
+        }
+        if (arg.length() > 0) {
+          argList.add(arg.toString());
+        }
+        String[] paramA = argList.toArray(new String[0]);
+        // END CMU MODIFICATION
+
+        // startsWith (rather than charAt(0)) also copes with an empty quoted first argument
+        if (paramA.length == 2 && paramA[0].startsWith("-")) {
+          argsList.add(paramA[0]);
+          argsList.add(paramA[1]);
+        } else if (paramA.length > 2 && (paramA[0].equals("-m") || paramA[0].equals("-docSet"))) {
+          // -m (metricName), -docSet are allowed to have extra options
+          Collections.addAll(argsList, paramA);
+        } else {
+          String msg = "Malformed line in config file:" + origLine;
+          throw new RuntimeException(msg);
+        }
+      }
+    } catch (FileNotFoundException e) {
+      println("AdaGrad configuration file " + fileName + " was not found!");
+      throw new RuntimeException(e);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    return argsList.toArray(new String[0]);
+  }
+
+ private void processArgsArray(String[] args) {
+ processArgsArray(args, true);
+ }
+
+ private void processArgsArray(String[] args, boolean firstTime) {
+ /* set default values */
+ // Relevant files
+ dirPrefix = null;
+ sourceFileName = null;
+ refFileName = "reference.txt";
+ refsPerSen = 1;
+ textNormMethod = 1;
+ paramsFileName = "params.txt";
+ docInfoFileName = null;
+ finalLambdaFileName = null;
+ // MERT specs
+ metricName = "BLEU";
+ metricName_display = metricName;
+ metricOptions = new String[2];
+ metricOptions[0] = "4";
+ metricOptions[1] = "closest";
+ docSubsetInfo = new int[7];
+ docSubsetInfo[0] = 0;
+ maxMERTIterations = 20;
+ prevMERTIterations = 20;
+ minMERTIterations = 5;
+ stopMinIts = 3;
+ stopSigValue = -1;
+ //
+ // /* possibly other early stopping criteria here */
+ //
+ numOptThreads = 1;
+ saveInterFiles = 3;
+ compressFiles = 0;
+ oneModificationPerIteration = false;
+ randInit = false;
+ seed = System.currentTimeMillis();
+ // useDisk = 2;
+ // Decoder specs
+ decoderCommandFileName = null;
+ passIterationToDecoder = false;
+ decoderOutFileName = "output.nbest";
+ validDecoderExitValue = 0;
+ decoderConfigFileName = "dec_cfg.txt";
+ sizeOfNBest = 100;
+ fakeFileNameTemplate = null;
+ fakeFileNamePrefix = null;
+ fakeFileNameSuffix = null;
+ // Output specs
+ verbosity = 1;
+ decVerbosity = 0;
+
+ int i = 0;
+
+ while (i < args.length) {
+ String option = args[i];
+ // Relevant files
+ if (option.equals("-dir")) {
+ dirPrefix = args[i + 1];
+ } else if (option.equals("-s")) {
+ sourceFileName = args[i + 1];
+ } else if (option.equals("-r")) {
+ refFileName = args[i + 1];
+ } else if (option.equals("-rps")) {
+ refsPerSen = Integer.parseInt(args[i + 1]);
+ if (refsPerSen < 1) {
+ throw new RuntimeException("refsPerSen must be positive.");
+ }
+ } else if (option.equals("-txtNrm")) {
+ textNormMethod = Integer.parseInt(args[i + 1]);
+ if (textNormMethod < 0 || textNormMethod > 4) {
+ throw new RuntimeException("textNormMethod should be between 0 and 4");
+ }
+ } else if (option.equals("-p")) {
+ paramsFileName = args[i + 1];
+ } else if (option.equals("-docInfo")) {
+ docInfoFileName = args[i + 1];
+ } else if (option.equals("-fin")) {
+ finalLambdaFileName = args[i + 1];
+ // MERT specs
+ } else if (option.equals("-m")) {
+ metricName = args[i + 1];
+ metricName_display = metricName;
+ if (EvaluationMetric.knownMetricName(metricName)) {
+ int optionCount = EvaluationMetric.metricOptionCount(metricName);
+ metricOptions = new String[optionCount];
+ for (int opt = 0; opt < optionCount; ++opt) {
+ metricOptions[opt] = args[i + opt + 2];
+ }
+ i += optionCount;
+ } else {
+ throw new RuntimeException("Unknown metric name " + metricName + ".");
+ }
+ } else if (option.equals("-docSet")) {
+ String method = args[i + 1];
+
+ if (method.equals("all")) {
+ docSubsetInfo[0] = 0;
+ i += 0;
+ } else if (method.equals("bottom")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 1;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 2;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("top")) {
+ String a = args[i + 2];
+ if (a.endsWith("d")) {
+ docSubsetInfo[0] = 3;
+ a = a.substring(0, a.indexOf("d"));
+ } else {
+ docSubsetInfo[0] = 4;
+ a = a.substring(0, a.indexOf("%"));
+ }
+ docSubsetInfo[5] = Integer.parseInt(a);
+ i += 1;
+ } else if (method.equals("window")) {
+ String a1 = args[i + 2];
+ a1 = a1.substring(0, a1.indexOf("d")); // size of window
+ String a2 = args[i + 4];
+ if (a2.indexOf("p") > 0) {
+ docSubsetInfo[0] = 5;
+ a2 = a2.substring(0, a2.indexOf("p"));
+ } else {
+
<TRUNCATED>
[17/17] incubator-joshua git commit: Merge branch 'master' into
7-with-master
Posted by mj...@apache.org.
Merge branch 'master' into 7-with-master
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/b0b70627
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/b0b70627
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/b0b70627
Branch: refs/heads/7
Commit: b0b70627225838a34508b8199c8b2156c1774ca4
Parents: bb3c300 840eb4c
Author: max thomas <ma...@maxthomas.io>
Authored: Tue Aug 30 16:24:24 2016 -0400
Committer: max thomas <ma...@maxthomas.io>
Committed: Tue Aug 30 16:24:24 2016 -0400
----------------------------------------------------------------------
.../org/apache/joshua/adagrad/AdaGradCore.java | 218 +------
.../org/apache/joshua/adagrad/Optimizer.java | 52 +-
.../joshua/corpus/syntax/ArraySyntaxTree.java | 20 +-
.../org/apache/joshua/decoder/ArgsParser.java | 24 +-
.../java/org/apache/joshua/decoder/Decoder.java | 33 +-
.../apache/joshua/decoder/JoshuaDecoder.java | 17 +-
.../decoder/StructuredTranslationFactory.java | 18 +-
.../org/apache/joshua/decoder/Translation.java | 33 +-
.../joshua/decoder/chart_parser/DotChart.java | 78 +--
.../apache/joshua/decoder/ff/TargetBigram.java | 5 +-
.../joshua/decoder/ff/fragmentlm/Tree.java | 103 ++--
.../org/apache/joshua/decoder/ff/lm/KenLM.java | 20 +-
.../joshua/decoder/ff/lm/buildin_lm/TrieLM.java | 168 ++----
.../joshua/decoder/ff/tm/CreateGlueGrammar.java | 42 +-
.../decoder/ff/tm/packed/PackedGrammar.java | 81 ++-
.../java/org/apache/joshua/lattice/Lattice.java | 54 +-
.../java/org/apache/joshua/metrics/CHRF.java | 97 ++--
.../java/org/apache/joshua/metrics/SARI.java | 117 ++--
.../java/org/apache/joshua/mira/MIRACore.java | 231 +-------
.../java/org/apache/joshua/mira/Optimizer.java | 26 +-
.../java/org/apache/joshua/pro/PROCore.java | 224 +-------
.../org/apache/joshua/tools/GrammarPacker.java | 86 +--
.../org/apache/joshua/tools/LabelPhrases.java | 84 +--
.../org/apache/joshua/tools/TestSetFilter.java | 83 +--
.../java/org/apache/joshua/util/BotMap.java | 94 ---
.../java/org/apache/joshua/util/Constants.java | 15 +-
.../org/apache/joshua/util/FileUtility.java | 261 +--------
.../org/apache/joshua/util/IntegerPair.java | 36 --
.../java/org/apache/joshua/util/ListUtil.java | 51 --
.../main/java/org/apache/joshua/util/Lists.java | 567 -------------------
.../org/apache/joshua/util/NullIterator.java | 65 ---
.../org/apache/joshua/util/QuietFormatter.java | 36 --
.../org/apache/joshua/util/ReverseOrder.java | 39 --
.../org/apache/joshua/util/SampledList.java | 69 ---
.../org/apache/joshua/util/SocketUtility.java | 144 -----
.../apache/joshua/util/encoding/Analyzer.java | 85 +--
.../util/encoding/FeatureTypeAnalyzer.java | 40 +-
.../util/io/ExistingUTF8EncodedTextFile.java | 77 +++
.../apache/joshua/util/io/IndexedReader.java | 13 +-
.../org/apache/joshua/util/io/LineReader.java | 123 ++--
.../org/apache/joshua/util/io/NullReader.java | 63 ---
.../java/org/apache/joshua/util/io/Reader.java | 11 +-
.../joshua/util/quantization/Quantizer.java | 48 +-
.../quantization/QuantizerConfiguration.java | 167 +++---
.../util/quantization/StatelessQuantizer.java | 26 +-
.../java/org/apache/joshua/zmert/MertCore.java | 205 +------
.../org/apache/joshua/packed/Benchmark.java | 28 +-
.../system/MultithreadedTranslationTests.java | 12 +-
48 files changed, 945 insertions(+), 3244 deletions(-)
----------------------------------------------------------------------