You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:51 UTC
[05/13] incubator-joshua git commit: Version with custom line
processor; still not working (alphabet problems)
Version with custom line processor; still not working (alphabet problems)
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/d17d08af
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/d17d08af
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/d17d08af
Branch: refs/heads/morph
Commit: d17d08afb449ce86c5f0d9ca5b2ccdc5ea777fda
Parents: 0ef5d3e
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 14:11:37 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 14:11:37 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 192 ++++++++++++++++---
1 file changed, 164 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d17d08af/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index edf4390..b4b455d 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -1,5 +1,7 @@
package joshua.decoder.ff.morph;
+import java.io.BufferedReader;
+
/***
* This feature function scores a rule application by predicting, for each target word aligned with
* a source word, how likely the lexical translation is in context.
@@ -20,11 +22,12 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
-import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.io.StringReader;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
@@ -32,9 +35,12 @@ import java.util.Scanner;
import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
+import cc.mallet.types.LabelAlphabet;
import joshua.corpus.Vocabulary;
+import joshua.decoder.Decoder;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.FeatureVector;
@@ -44,21 +50,28 @@ import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.segment_file.Sentence;
import joshua.decoder.segment_file.Token;
+import joshua.util.io.LineReader;
public class LexicalSharpener extends StatelessFF {
- private Classifier classifier = null;
+ private HashMap<Integer,Object> classifiers = null;
private SerialPipes pipes = null;
+ private InstanceList instances = null;
+ private LabelAlphabet labelAlphabet = null;
+ private Alphabet dataAlphabet = null;
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+ dataAlphabet = new Alphabet();
+ labelAlphabet = new LabelAlphabet();
+
// I don't know if this is needed
- pipeList.add(new Target2Label());
- // Convert SVM-light format to sparse feature vector
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
+// pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+ // Convert custom lines to Instance objects (SvmLight2FeatureVectorAndLabel is not versatile enough)
+ pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
// Validation
// pipeList.add(new PrintInputAndTarget());
@@ -68,13 +81,15 @@ public class LexicalSharpener extends StatelessFF {
// source: null
pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(dataAlphabet, labelAlphabet);
+ instances.setPipe(pipes);
if (parsedArgs.containsKey("model")) {
String modelFile = parsedArgs.get("model");
if (! new File(modelFile).exists()) {
if (parsedArgs.getOrDefault("training-data", null) != null) {
try {
- classifier = train(parsedArgs.get("training-data"));
+ train(parsedArgs.get("training-data"));
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
@@ -85,7 +100,7 @@ public class LexicalSharpener extends StatelessFF {
}
} else {
try {
- loadClassifier(modelFile);
+ loadClassifiers(modelFile);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
@@ -104,39 +119,103 @@ public class LexicalSharpener extends StatelessFF {
* @return
* @throws FileNotFoundException
*/
- public Classifier train(String dataFile) throws FileNotFoundException {
-
- // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
- InstanceList instances = new InstanceList(pipes);
- instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
- "(\\w+)\\s+(.*)",
- 2, -1, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- Classifier classifier = trainer.train(instances);
+ public void train(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Object>();
- return classifier;
+ Decoder.VERBOSE = 1;
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastOutcome = null;
+ String buffer = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String outcome = line.substring(0, line.indexOf(' '));
+ if (lastOutcome != null && ! outcome.equals(lastOutcome)) {
+ classifiers.put(Vocabulary.id(lastOutcome), buffer);
+// System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ buffer = "";
+ }
+
+ buffer += line + "\n";
+ lastOutcome = outcome;
+ linesRead++;
+ }
+ classifiers.put(Vocabulary.id(lastOutcome), buffer);
+
+ System.err.println(String.format("Read %d lines from training file", linesRead));
}
- public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
+ public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifier = (Classifier) ois.readObject();
+ classifiers = (HashMap<Integer,Object>) ois.readObject();
}
- public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
+ public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
- oos.writeObject(classifier);
+ oos.writeObject(classifiers);
oos.close();
}
- public Classification predict(String outcome, String features) {
- Instance instance = new Instance(features, null, null, null);
- System.err.println("PREDICT outcome = " + (String) instance.getTarget());
- System.err.println("PREDICT features = " + (String) instance.getData());
+ /**
+ * Returns a Classification object for a list of features. Uses "which" to determine which classifier
+ * to use.
+ *
+ * @param which the classifier to use
+ * @param features the set of features
+ * @return
+ */
+ public Classification predict(String which, String features) {
+ Instance instance = new Instance(features, which, null, null);
+// System.err.println("PREDICT outcome = " + (String) instance.getTarget());
+// System.err.println("PREDICT features = " + (String) instance.getData());
+
+ Classifier classifier = getClassifier(which);
Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
return result;
}
+
+ public Classifier getClassifier(String which) {
+ int id = Vocabulary.id(which);
+
+ if (classifiers.containsKey(id)) {
+ if (classifiers.get(id) instanceof String) {
+ StringReader reader = new StringReader((String)classifiers.get(id));
+
+ System.err.println("training new classifier from string for " + which);
+
+ BufferedReader t = new BufferedReader(reader);
+ String line;
+ try {
+ while ((line = t.readLine()) != null) {
+ System.out.println(" LINE: " + line);
+ }
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
+ instances.addThruPipe(new CsvIterator(reader, "(\\w+)\\s+(.*)", 2, -1, -1));
+
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ Classifier classifier = trainer.train(instances);
+
+ classifiers.put(id, classifier);
+ }
+
+ return (Classifier) classifiers.get(id);
+ }
+
+ return null;
+ }
/**
* Compute features. This works by walking over the target side phrase pieces, looking for every
@@ -211,6 +290,63 @@ public class LexicalSharpener extends StatelessFF {
return anchoredSource;
}
+
+ public class CustomLineProcessor extends Pipe {
+
+ private static final long serialVersionUID = 1L;
+
+ public CustomLineProcessor() {
+ super (new Alphabet(), new LabelAlphabet());
+ }
+
+ public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
+ super(data, label);
+ }
+
+ @Override
+ public Instance pipe(Instance carrier) {
+ // we expect the data for each instance to be
+ // a line from the SVMLight format text file
+ String dataStr = (String)carrier.getData();
+
+ String[] tokens = dataStr.split("\\s+");
+
+ carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+
+ // the rest are feature-value pairs
+ ArrayList<Integer> indices = new ArrayList<Integer>();
+ ArrayList<Double> values = new ArrayList<Double>();
+ for (int termIndex = 1; termIndex < tokens.length; termIndex++) {
+ if (!tokens[termIndex].equals("")) {
+ String[] s = tokens[termIndex].split(":");
+ if (s.length != 2) {
+ throw new RuntimeException("invalid format: " + tokens[termIndex] + " (should be feature:value)");
+ }
+ String feature = s[0];
+ int index = getDataAlphabet().lookupIndex(feature, true);
+
+ // index may be -1 if growth of the
+ // data alphabet is stopped
+ if (index >= 0) {
+ indices.add(index);
+ values.add(Double.parseDouble(s[1]));
+ }
+ }
+ }
+
+ assert(indices.size() == values.size());
+ int[] indicesArr = new int[indices.size()];
+ double[] valuesArr = new double[values.size()];
+ for (int i = 0; i < indicesArr.length; i++) {
+ indicesArr[i] = indices.get(i);
+ valuesArr[i] = values.get(i);
+ }
+
+ cc.mallet.types.FeatureVector fv = new cc.mallet.types.FeatureVector(getDataAlphabet(), indicesArr, valuesArr);
+ carrier.setData(fv);
+ return carrier;
+ }
+ }
public static void main(String[] args) throws IOException, ClassNotFoundException {
LexicalSharpener ts = new LexicalSharpener(null, args, null);
@@ -227,10 +363,10 @@ public class LexicalSharpener extends StatelessFF {
modelFile = args[1];
System.err.println("Writing model to file " + modelFile);
- ts.saveClassifier(modelFile);
+ ts.saveClassifiers(modelFile);
} else {
System.err.println("Loading model from file " + modelFile);
- ts.loadClassifier(modelFile);
+ ts.loadClassifiers(modelFile);
}
Scanner stdin = new Scanner(System.in);