You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:54 UTC
[08/13] incubator-joshua git commit: Rewrote in less efficient way,
now works!
Rewrote in less efficient way, now works!
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/b121fc20
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/b121fc20
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/b121fc20
Branch: refs/heads/morph
Commit: b121fc20c22dd6752ccf3e728ed8c23ca73eb9aa
Parents: 23306b4
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 17:59:50 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 17:59:50 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 267 ++++++++-----------
1 file changed, 110 insertions(+), 157 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b121fc20/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index a66f4a1..701a72f 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -55,42 +55,16 @@ import joshua.util.io.LineReader;
public class LexicalSharpener extends StatelessFF {
- private HashMap<Integer,Object> classifiers = null;
- private SerialPipes pipes = null;
- private InstanceList instances = null;
- private LabelAlphabet labelAlphabet = null;
- private Alphabet dataAlphabet = null;
-
+ private HashMap<Integer,Predictor> classifiers = null;
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
-
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- dataAlphabet = new Alphabet();
- labelAlphabet = new LabelAlphabet();
-
- // I don't know if this is needed
- pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
- // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
- pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
- // Validation
-// pipeList.add(new PrintInputAndTarget());
-
- // name: english word
- // data: features (FeatureVector)
- // target: foreign inflection
- // source: null
-
- pipes = new SerialPipes(pipeList);
- instances = new InstanceList(dataAlphabet, labelAlphabet);
- instances.setPipe(pipes);
if (parsedArgs.containsKey("model")) {
String modelFile = parsedArgs.get("model");
if (! new File(modelFile).exists()) {
if (parsedArgs.getOrDefault("training-data", null) != null) {
try {
- train(parsedArgs.get("training-data"));
+ trainAll(parsedArgs.get("training-data"));
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
@@ -113,49 +87,9 @@ public class LexicalSharpener extends StatelessFF {
}
}
- /**
- * Trains a maxent classifier from the provided training data, returning a Mallet model.
- *
- * @param dataFile
- * @return
- * @throws FileNotFoundException
- */
- public void train(String dataFile) throws FileNotFoundException {
-
- classifiers = new HashMap<Integer, Object>();
-
- Decoder.VERBOSE = 1;
- LineReader lineReader = null;
- try {
- lineReader = new LineReader(dataFile, true);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
-
- String lastOutcome = null;
- String buffer = "";
- int linesRead = 0;
- for (String line : lineReader) {
- String outcome = line.substring(0, line.indexOf(' '));
- if (lastOutcome != null && ! outcome.equals(lastOutcome)) {
- classifiers.put(Vocabulary.id(lastOutcome), buffer);
-// System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
- buffer = "";
- }
-
- buffer += line + "\n";
- lastOutcome = outcome;
- linesRead++;
- }
- classifiers.put(Vocabulary.id(lastOutcome), buffer);
-
- System.err.println(String.format("Read %d lines from training file", linesRead));
- }
-
public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifiers = (HashMap<Integer,Object>) ois.readObject();
+ classifiers = (HashMap<Integer,Predictor>) ois.readObject();
ois.close();
System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
@@ -171,53 +105,6 @@ public class LexicalSharpener extends StatelessFF {
}
/**
- * Returns a Classification object a list of features. Uses "which" to determine which classifier
- * to use.
- *
- * @param which the classifier to use
- * @param features the set of features
- * @return
- */
- public Classification predict(String which, String features) {
- Instance instance = new Instance(features, which, which, which);
-// System.err.println("PREDICT outcome = " + (String) instance.getTarget());
-// System.err.println("PREDICT features = " + (String) instance.getData());
-
- Classifier classifier = getClassifier(which);
- if (classifier != null) {
- Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
- return result;
- }
-
- return null;
- }
-
- public Classifier getClassifier(String which) {
- int id = Vocabulary.id(which);
-
- if (classifiers.containsKey(id)) {
- if (classifiers.get(id) instanceof String) {
- StringReader reader = new StringReader((String)classifiers.get(id));
-
- System.err.println("training new classifier from string for " + which);
- System.err.println(String.format("training string is: '%s'", (String)classifiers.get(id)));
-
- // Constructs an instance with everything shoved into the data field
- instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- Classifier classifier = trainer.train(instances);
-
- classifiers.put(id, classifier);
- }
-
- return (Classifier) classifiers.get(id);
- }
-
- return null;
- }
-
- /**
* Compute features. This works by walking over the target side phrase pieces, looking for every
* word with a single source-aligned word. We then throw the annotations from that source word
* into our prediction model to learn how much it likes the chosen word. Presumably the source-
@@ -238,7 +125,7 @@ public class LexicalSharpener extends StatelessFF {
Token sourceToken = sentence.getTokens().get(s);
String featureString = sourceToken.getAnnotationString().replace('|', ' ');
- Classification result = predict(targetWord, featureString);
+ Classification result = predict(sourceToken.getWord(), featureString);
if (result.bestLabelIsCorrect()) {
acc.add(String.format("%s_match", name), 1);
}
@@ -247,6 +134,17 @@ public class LexicalSharpener extends StatelessFF {
return null;
}
+ public Classification predict(int id, String featureString) {
+ String word = Vocabulary.word(id);
+ if (classifiers.containsKey(id)) {
+ Predictor predictor = classifiers.get(id);
+ if (predictor != null)
+ return predictor.predict(featureString);
+ }
+
+ return null;
+ }
+
/**
* Returns an array parallel to the source words array indicating, for each index, the absolute
* position of that word into the source sentence. For example, for the rule with source side
@@ -291,54 +189,109 @@ public class LexicalSharpener extends StatelessFF {
return anchoredSource;
}
- public static class CustomLineProcessor extends Pipe {
+ public class Predictor {
+
+ private SerialPipes pipes = null;
+ private InstanceList instances = null;
+ private String sourceWord = null;
+ private String examples = null;
+ private Classifier classifier = null;
+
+ public Predictor(String word, String examples) {
+ this.sourceWord = word;
+ this.examples = examples;
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
- private static final long serialVersionUID = 1L;
+ pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(pipes);
+ }
+
+ /**
+   * Returns a Classification object given a list of features, training this
+   * word's classifier lazily on first use.
+   *
+   * @param features the set of features
+   * @return the classification result
+ */
+ public Classification predict(String features) {
+ Instance instance = new Instance(features, null, null, null);
+ // System.err.println("PREDICT sourceWord = " + (String) instance.getTarget());
+ // System.err.println("PREDICT features = " + (String) instance.getData());
+
+ if (classifier == null)
+ train();
- public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
- super(data, label);
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+ return result;
}
-
- @Override
- public Instance pipe(Instance carrier) {
- // we expect the data for each instance to be
- // a line from the SVMLight format text file
- String dataStr = (String)carrier.getData();
- String[] tokens = dataStr.split("\\s+");
+ public void train() {
+ System.err.println(String.format("Word %s: training model", sourceWord));
+ System.err.println(String.format(" Examples: %s", examples));
+
+ StringReader reader = new StringReader(examples);
-// carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+ // Constructs an instance with everything shoved into the data field
+ instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
- // the rest are feature-value pairs
- ArrayList<Integer> indices = new ArrayList<Integer>();
- ArrayList<Double> values = new ArrayList<Double>();
- for (int termIndex = 0; termIndex < tokens.length; termIndex++) {
- if (!tokens[termIndex].equals("")) {
- String[] s = tokens[termIndex].split(":");
- if (s.length != 2) {
- throw new RuntimeException("invalid format: " + tokens[termIndex] + " (should be feature:value)");
- }
- String feature = s[0];
- int index = getDataAlphabet().lookupIndex(feature, true);
- indices.add(index);
- values.add(Double.parseDouble(s[1]));
- }
- }
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ classifier = trainer.train(instances);
+ }
+ }
+
+ /**
+   * Reads the provided training data and builds a lazy maxent classifier
+   * (Predictor) for each source word, populating the classifiers map keyed
+   * by vocabulary ID.
+   *
+   * @param dataFile path to the training data file
+   * @throws FileNotFoundException if the training file cannot be opened
+ */
+ public void trainAll(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Predictor>();
- assert(indices.size() == values.size());
- int[] indicesArr = new int[indices.size()];
- double[] valuesArr = new double[values.size()];
- for (int i = 0; i < indicesArr.length; i++) {
- indicesArr[i] = indices.get(i);
- valuesArr[i] = values.get(i);
+ Decoder.VERBOSE = 1;
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastSourceWord = null;
+ String examples = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String sourceWord = line.substring(0, line.indexOf(' '));
+ if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ examples = "";
}
- cc.mallet.types.FeatureVector fv = new cc.mallet.types.FeatureVector(getDataAlphabet(), indicesArr, valuesArr);
- carrier.setData(fv);
- return carrier;
+ examples += line + "\n";
+ lastSourceWord = sourceWord;
+ linesRead++;
}
- }
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ System.err.println(String.format("Read %d lines from training file", linesRead));
+ }
+
public static void example(String[] args) throws IOException, ClassNotFoundException {
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
@@ -348,7 +301,7 @@ public class LexicalSharpener extends StatelessFF {
pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
// Basically, SvmLight but with a custom (fixed) alphabet)
- pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
FileReader reader1 = new FileReader("data.1");
FileReader reader2 = new FileReader("data.2");
@@ -377,7 +330,7 @@ public class LexicalSharpener extends StatelessFF {
String dataFile = args[0];
System.err.println("Training model from file " + dataFile);
- ts.train(dataFile);
+ ts.trainAll(dataFile);
// if (args.length > 1)
// modelFile = args[1];
@@ -395,7 +348,7 @@ public class LexicalSharpener extends StatelessFF {
String[] tokens = line.split(" ", 2);
String sourceWord = tokens[0];
String features = tokens[1];
- Classification result = ts.predict(sourceWord, features);
+ Classification result = ts.predict(Vocabulary.id(sourceWord), features);
if (result != null)
System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
else