You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:55 UTC
[09/13] incubator-joshua git commit: fiddling
fiddling
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/f0728fac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/f0728fac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/f0728fac
Branch: refs/heads/morph
Commit: f0728fac70bbacc7b33937f4915a091e4cfe4e36
Parents: b121fc2
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:11:54 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:11:54 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 159 +++++++++----------
1 file changed, 79 insertions(+), 80 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/f0728fac/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index 701a72f..7982db4 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -1,7 +1,5 @@
package joshua.decoder.ff.morph;
-import java.io.BufferedReader;
-
/***
* This feature function scores a rule application by predicting, for each target word aligned with
* a source word, how likely the lexical translation is in context.
@@ -59,34 +57,56 @@ public class LexicalSharpener extends StatelessFF {
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
- if (parsedArgs.containsKey("model")) {
- String modelFile = parsedArgs.get("model");
- if (! new File(modelFile).exists()) {
- if (parsedArgs.getOrDefault("training-data", null) != null) {
- try {
- trainAll(parsedArgs.get("training-data"));
- } catch (FileNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- } else {
- System.err.println("* FATAL: no model and no training data.");
- System.exit(1);
- }
- } else {
- try {
- loadClassifiers(modelFile);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (ClassNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
+ if (parsedArgs.getOrDefault("training-data", null) != null) {
+ try {
+ trainAll(parsedArgs.get("training-data"));
+ } catch (FileNotFoundException e) {
+ System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
+ System.exit(1);
}
}
}
+ /**
+ * Reads the training data and builds one Predictor per source word, storing them in classifiers.
+ *
+ * @param dataFile
+ * @return
+ * @throws FileNotFoundException
+ */
+ public void trainAll(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Predictor>();
+
+ Decoder.LOG(1, "Reading " + dataFile);
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastSourceWord = null;
+ String examples = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String sourceWord = line.substring(0, line.indexOf(' '));
+ if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ examples = "";
+ }
+
+ examples += line + "\n";
+ lastSourceWord = sourceWord;
+ linesRead++;
+ }
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+
+ System.err.println(String.format("Read %d lines from training file", linesRead));
+ }
+
public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
classifiers = (HashMap<Integer,Predictor>) ois.readObject();
@@ -113,6 +133,8 @@ public class LexicalSharpener extends StatelessFF {
@Override
public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
Sentence sentence, Accumulator acc) {
+
+ System.err.println(String.format("RULE: %s", rule));
Map<Integer, List<Integer>> points = rule.getAlignmentMap();
for (int t: points.keySet()) {
@@ -120,12 +142,14 @@ public class LexicalSharpener extends StatelessFF {
if (source_indices.size() != 1)
continue;
- String targetWord = Vocabulary.word(rule.getEnglish()[t]);
+ int targetID = rule.getEnglish()[t];
+ String targetWord = Vocabulary.word(targetID);
int s = i + source_indices.get(0);
Token sourceToken = sentence.getTokens().get(s);
String featureString = sourceToken.getAnnotationString().replace('|', ' ');
- Classification result = predict(sourceToken.getWord(), featureString);
+ Classification result = predict(sourceToken.getWord(), targetID, featureString);
+ System.out.println("RESULT: " + result.getLabeling());
if (result.bestLabelIsCorrect()) {
acc.add(String.format("%s_match", name), 1);
}
@@ -134,12 +158,12 @@ public class LexicalSharpener extends StatelessFF {
return null;
}
- public Classification predict(int id, String featureString) {
- String word = Vocabulary.word(id);
- if (classifiers.containsKey(id)) {
- Predictor predictor = classifiers.get(id);
+ public Classification predict(int sourceID, int targetID, String featureString) {
+ String word = Vocabulary.word(sourceID);
+ if (classifiers.containsKey(sourceID)) {
+ Predictor predictor = classifiers.get(sourceID);
if (predictor != null)
- return predictor.predict(featureString);
+ return predictor.predict(Vocabulary.word(targetID), featureString);
}
return null;
@@ -226,10 +250,10 @@ public class LexicalSharpener extends StatelessFF {
* @param features the set of features
* @return
*/
- public Classification predict(String features) {
- Instance instance = new Instance(features, null, null, null);
- // System.err.println("PREDICT sourceWord = " + (String) instance.getTarget());
- // System.err.println("PREDICT features = " + (String) instance.getData());
+ public Classification predict(String outcome, String features) {
+ Instance instance = new Instance(features, outcome, null, null);
+ System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
+ System.err.println("PREDICT features = " + (String) instance.getData());
if (classifier == null)
train();
@@ -239,8 +263,8 @@ public class LexicalSharpener extends StatelessFF {
}
public void train() {
- System.err.println(String.format("Word %s: training model", sourceWord));
- System.err.println(String.format(" Examples: %s", examples));
+// System.err.println(String.format("Word %s: training model", sourceWord));
+// System.err.println(String.format(" Examples: %s", examples));
StringReader reader = new StringReader(examples);
@@ -249,47 +273,21 @@ public class LexicalSharpener extends StatelessFF {
ClassifierTrainer trainer = new MaxEntTrainer();
classifier = trainer.train(instances);
- }
- }
-
- /**
- * Trains a maxent classifier from the provided training data, returning a Mallet model.
- *
- * @param dataFile
- * @return
- * @throws FileNotFoundException
- */
- public void trainAll(String dataFile) throws FileNotFoundException {
-
- classifiers = new HashMap<Integer, Predictor>();
-
- Decoder.VERBOSE = 1;
- LineReader lineReader = null;
- try {
- lineReader = new LineReader(dataFile, true);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
+
+ System.err.println(String.format("Trained a model for %s with %d outcomes",
+ sourceWord, pipes.getTargetAlphabet().size()));
}
- String lastSourceWord = null;
- String examples = "";
- int linesRead = 0;
- for (String line : lineReader) {
- String sourceWord = line.substring(0, line.indexOf(' '));
- if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
- // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
- examples = "";
- }
-
- examples += line + "\n";
- lastSourceWord = sourceWord;
- linesRead++;
+ /**
+ * Returns the number of distinct outcomes, training the model first if it has not been trained yet.
+ *
+ * @return
+ */
+ public int getNumOutcomes() {
+ if (classifier == null)
+ train();
+ return pipes.getTargetAlphabet().size();
}
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
-
- System.err.println(String.format("Read %d lines from training file", linesRead));
}
public static void example(String[] args) throws IOException, ClassNotFoundException {
@@ -345,10 +343,11 @@ public class LexicalSharpener extends StatelessFF {
Scanner stdin = new Scanner(System.in);
while(stdin.hasNextLine()) {
String line = stdin.nextLine();
- String[] tokens = line.split(" ", 2);
+ String[] tokens = line.split(" ", 3);
String sourceWord = tokens[0];
- String features = tokens[1];
- Classification result = ts.predict(Vocabulary.id(sourceWord), features);
+ String targetWord = tokens[1];
+ String features = tokens[2];
+ Classification result = ts.predict(Vocabulary.id(sourceWord), Vocabulary.id(targetWord), features);
if (result != null)
System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
else