You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/24 22:53:26 UTC
[05/18] incubator-joshua git commit: Model now serializes
Model now serializes
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/68b01bc1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/68b01bc1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/68b01bc1
Branch: refs/heads/morph
Commit: 68b01bc168298db382334e9f01bdf2992db85b01
Parents: 1c8aaa5
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 15:59:39 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 15:59:39 2016 -0400
----------------------------------------------------------------------
src/joshua/decoder/ff/LexicalSharpener.java | 184 ++++++-----------------
src/joshua/decoder/ff/MalletPredictor.java | 97 ++++++++++++
2 files changed, 143 insertions(+), 138 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/68b01bc1/src/joshua/decoder/ff/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/LexicalSharpener.java b/src/joshua/decoder/ff/LexicalSharpener.java
index 2c96f83..8671d57 100644
--- a/src/joshua/decoder/ff/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/LexicalSharpener.java
@@ -19,24 +19,16 @@ package joshua.decoder.ff;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
-import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
-import java.io.StringReader;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import cc.mallet.classify.*;
-import cc.mallet.pipe.*;
-import cc.mallet.pipe.iterator.CsvIterator;
-import cc.mallet.types.Alphabet;
-import cc.mallet.types.Instance;
-import cc.mallet.types.InstanceList;
-import cc.mallet.types.LabelAlphabet;
+import cc.mallet.types.Labeling;
import joshua.corpus.Vocabulary;
import joshua.decoder.Decoder;
import joshua.decoder.JoshuaConfiguration;
@@ -52,7 +44,8 @@ import joshua.util.io.LineReader;
public class LexicalSharpener extends StatelessFF {
- private HashMap<Integer,Predictor> classifiers = null;
+ private HashMap<String,MalletPredictor> classifiers = null;
+
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
@@ -63,6 +56,13 @@ public class LexicalSharpener extends StatelessFF {
System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
System.exit(1);
}
+ } else if (parsedArgs.containsKey("model")) {
+ try {
+ loadClassifiers(parsedArgs.get("model"));
+ } catch (ClassNotFoundException | IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
}
}
@@ -75,7 +75,7 @@ public class LexicalSharpener extends StatelessFF {
*/
public void trainAll(String dataFile) throws FileNotFoundException {
- classifiers = new HashMap<Integer, Predictor>();
+ classifiers = new HashMap<String, MalletPredictor>();
Decoder.LOG(1, "Reading " + dataFile);
LineReader lineReader = null;
@@ -92,7 +92,7 @@ public class LexicalSharpener extends StatelessFF {
for (String line : lineReader) {
String sourceWord = line.substring(0, line.indexOf(' '));
if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ classifiers.put(lastSourceWord, new MalletPredictor(lastSourceWord, examples));
// System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
examples = "";
}
@@ -101,18 +101,18 @@ public class LexicalSharpener extends StatelessFF {
lastSourceWord = sourceWord;
linesRead++;
}
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ classifiers.put(lastSourceWord, new MalletPredictor(lastSourceWord, examples));
System.err.println(String.format("Read %d lines from training file", linesRead));
}
public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifiers = (HashMap<Integer,Predictor>) ois.readObject();
+ classifiers = (HashMap<String,MalletPredictor>) ois.readObject();
ois.close();
System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
- for (int key: classifiers.keySet()) {
+ for (String key: classifiers.keySet()) {
System.err.println(" " + key);
}
}
@@ -133,8 +133,6 @@ public class LexicalSharpener extends StatelessFF {
public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
Sentence sentence, Accumulator acc) {
- System.err.println(String.format("RULE: %s", rule));
-
Map<Integer, List<Integer>> points = rule.getAlignmentMap();
for (int t: points.keySet()) {
List<Integer> source_indices = points.get(t);
@@ -142,27 +140,46 @@ public class LexicalSharpener extends StatelessFF {
continue;
int targetID = rule.getEnglish()[t];
- String targetWord = Vocabulary.word(targetID);
int s = i + source_indices.get(0);
Token sourceToken = sentence.getTokens().get(s);
String featureString = sourceToken.getAnnotationString().replace('|', ' ');
Classification result = predict(sourceToken.getWord(), targetID, featureString);
- System.out.println("RESULT: " + result.getLabeling());
- if (result.bestLabelIsCorrect()) {
- acc.add(String.format("%s_match", name), 1);
+ if (result != null) {
+ Labeling labeling = result.getLabeling();
+ int num = labeling.numLocations();
+ int predicted = Vocabulary.id(labeling.getBestLabel().toString());
+// System.err.println(String.format("LexicalSharpener: predicted %s (rule %s) %.5f",
+// labeling.getBestLabel().toString(), Vocabulary.word(targetID), Math.log(labeling.getBestValue())));
+ if (num > 1 && predicted == targetID) {
+ acc.add(String.format("%s_match_%s", name, getBin(num)), 1);
+ }
+ acc.add(String.format("%s_weight", name), (float) Math.log(labeling.getBestValue()));
}
}
return null;
}
+ private String getBin(int num) {
+ if (num == 2)
+ return "2";
+ else if (num <= 5)
+ return "3-5";
+ else if (num <= 10)
+ return "6-10";
+ else if (num <= 20)
+ return "11-20";
+ else
+ return "21+";
+ }
+
public Classification predict(int sourceID, int targetID, String featureString) {
String word = Vocabulary.word(sourceID);
- if (classifiers.containsKey(sourceID)) {
- Predictor predictor = classifiers.get(sourceID);
+ if (classifiers.containsKey(word)) {
+ MalletPredictor predictor = classifiers.get(word);
if (predictor != null)
- return predictor.predict(Vocabulary.word(targetID), featureString);
+ return predictor.predict(word, featureString);
}
return null;
@@ -212,112 +229,6 @@ public class LexicalSharpener extends StatelessFF {
return anchoredSource;
}
- public class Predictor {
-
- private SerialPipes pipes = null;
- private InstanceList instances = null;
- private String sourceWord = null;
- private String examples = null;
- private Classifier classifier = null;
-
- public Predictor(String word, String examples) {
- this.sourceWord = word;
- this.examples = examples;
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- // I don't know if this is needed
- pipeList.add(new Target2Label());
- // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
- // Validation
-// pipeList.add(new PrintInputAndTarget());
-
- // name: english word
- // data: features (FeatureVector)
- // target: foreign inflection
- // source: null
-
- pipes = new SerialPipes(pipeList);
- instances = new InstanceList(pipes);
- }
-
- /**
- * Returns a Classification object a list of features. Uses "which" to determine which classifier
- * to use.
- *
- * @param which the classifier to use
- * @param features the set of features
- * @return
- */
- public Classification predict(String outcome, String features) {
- Instance instance = new Instance(features, outcome, null, null);
- System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
- System.err.println("PREDICT features = " + (String) instance.getData());
-
- if (classifier == null)
- train();
-
- Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
- return result;
- }
-
- public void train() {
-// System.err.println(String.format("Word %s: training model", sourceWord));
-// System.err.println(String.format(" Examples: %s", examples));
-
- StringReader reader = new StringReader(examples);
-
- // Constructs an instance with everything shoved into the data field
- instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- classifier = trainer.train(instances);
-
- System.err.println(String.format("Trained a model for %s with %d outcomes",
- sourceWord, pipes.getTargetAlphabet().size()));
- }
-
- /**
- * Returns the number of distinct outcomes. Requires the model to have been trained!
- *
- * @return
- */
- public int getNumOutcomes() {
- if (classifier == null)
- train();
- return pipes.getTargetAlphabet().size();
- }
- }
-
- public static void example(String[] args) throws IOException, ClassNotFoundException {
-
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- Alphabet dataAlphabet = new Alphabet();
- LabelAlphabet labelAlphabet = new LabelAlphabet();
-
- pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
- // Basically, SvmLight but with a custom (fixed) alphabet)
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
-
- FileReader reader1 = new FileReader("data.1");
- FileReader reader2 = new FileReader("data.2");
-
- SerialPipes pipes = new SerialPipes(pipeList);
- InstanceList instances = new InstanceList(dataAlphabet, labelAlphabet);
- instances.setPipe(pipes);
- instances.addThruPipe(new CsvIterator(reader1, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
- ClassifierTrainer trainer1 = new MaxEntTrainer();
- Classifier classifier1 = trainer1.train(instances);
-
- pipes = new SerialPipes(pipeList);
- instances = new InstanceList(dataAlphabet, labelAlphabet);
- instances.setPipe(pipes);
- instances.addThruPipe(new CsvIterator(reader2, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
- ClassifierTrainer trainer2 = new MaxEntTrainer();
- Classifier classifier2 = trainer2.train(instances);
- }
-
public static void main(String[] args) throws IOException, ClassNotFoundException {
LexicalSharpener ts = new LexicalSharpener(null, args, null);
@@ -329,14 +240,11 @@ public class LexicalSharpener extends StatelessFF {
System.err.println("Training model from file " + dataFile);
ts.trainAll(dataFile);
-// if (args.length > 1)
-// modelFile = args[1];
-//
-// System.err.println("Writing model to file " + modelFile);
-// ts.saveClassifiers(modelFile);
-// } else {
-// System.err.println("Loading model from file " + modelFile);
-// ts.loadClassifiers(modelFile);
+ if (args.length > 1)
+ modelFile = args[1];
+
+ System.err.println("Writing model to file " + modelFile);
+ ts.saveClassifiers(modelFile);
}
Scanner stdin = new Scanner(System.in);
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/68b01bc1/src/joshua/decoder/ff/MalletPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/MalletPredictor.java b/src/joshua/decoder/ff/MalletPredictor.java
new file mode 100644
index 0000000..04c9d8c
--- /dev/null
+++ b/src/joshua/decoder/ff/MalletPredictor.java
@@ -0,0 +1,97 @@
+package joshua.decoder.ff;
+
+import java.io.Serializable;
+import java.io.StringReader;
+import java.util.ArrayList;
+
+import cc.mallet.classify.Classification;
+import cc.mallet.classify.Classifier;
+import cc.mallet.classify.ClassifierTrainer;
+import cc.mallet.classify.MaxEntTrainer;
+import cc.mallet.pipe.Pipe;
+import cc.mallet.pipe.SerialPipes;
+import cc.mallet.pipe.SvmLight2FeatureVectorAndLabel;
+import cc.mallet.pipe.Target2Label;
+import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Instance;
+import cc.mallet.types.InstanceList;
+import joshua.decoder.Decoder;
+
+public class MalletPredictor implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ private SerialPipes pipes = null;
+ private InstanceList instances = null;
+ private String sourceWord = null;
+ private String examples = null;
+ private Classifier classifier = null;
+
+ public MalletPredictor(String word, String examples) {
+ this.sourceWord = word;
+ this.examples = examples;
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
+
+ pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(pipes);
+ }
+
+ /**
+ * Returns a Classification object given a list of features.
+ *
+ * @param outcome the expected label (the target word)
+ * @param features the set of features
+ * @return the classification result
+ */
+ public Classification predict(String outcome, String features) {
+ Instance instance = new Instance(features, outcome, null, null);
+// System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
+// System.err.println("PREDICT features = " + (String) instance.getData());
+
+ if (classifier == null)
+ train();
+
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+ return result;
+ }
+
+ public void train() {
+ Decoder.LOG(2, String.format("Word %s: training model from %d examples",
+ sourceWord, examples.split("\\n").length));
+
+ StringReader reader = new StringReader(examples);
+
+ // Constructs an instance with everything shoved into the data field
+ instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
+
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ classifier = trainer.train(instances);
+
+// Decoder.LOG(1, String.format("%s: Trained a model for %s with %d outcomes",
+// name, sourceWord, pipes.getTargetAlphabet().size()));
+ }
+
+ /**
+ * Returns the number of distinct outcomes. Requires the model to have been trained!
+ *
+ * @return
+ */
+ public int getNumOutcomes() {
+ if (classifier == null)
+ train();
+ return pipes.getTargetAlphabet().size();
+ }
+ }
\ No newline at end of file