You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:47 UTC
[01/13] incubator-joshua git commit: Added inflection predictor
training code (calls to mallet)
Repository: incubator-joshua
Updated Branches:
refs/heads/morph [created] a86ae8e87
Added inflection predictor training code (calls to mallet)
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/e3ad1a69
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/e3ad1a69
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/e3ad1a69
Branch: refs/heads/morph
Commit: e3ad1a6936422c685756ce4ab8c2b85dc2a449f2
Parents: 5396c5f
Author: Matt Post <po...@cs.jhu.edu>
Authored: Wed Apr 20 17:19:54 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Wed Apr 20 17:19:54 2016 -0400
----------------------------------------------------------------------
lib/ivy.xml | 2 +
.../decoder/ff/morph/InflectionPredictor.java | 98 ++++++++++++++++++++
2 files changed, 100 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e3ad1a69/lib/ivy.xml
----------------------------------------------------------------------
diff --git a/lib/ivy.xml b/lib/ivy.xml
index d41595d..dfa95ab 100644
--- a/lib/ivy.xml
+++ b/lib/ivy.xml
@@ -13,5 +13,7 @@
<dependency org="args4j" name="args4j" rev="2.0.29" />
<dependency org="com.google.code.gson" name="gson" rev="2.5"/>
<dependency org="com.google.guava" name="guava" rev="19.0"/>
+ <dependency org="cc.mallet" name="mallet" rev="2.0.7"/>
+ <dependency org="net.sf.trove4j" name="trove4j" rev="2.0.2"/>
</dependencies>
</ivy-module>
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/e3ad1a69/src/joshua/decoder/ff/morph/InflectionPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/InflectionPredictor.java b/src/joshua/decoder/ff/morph/InflectionPredictor.java
new file mode 100644
index 0000000..82d52ea
--- /dev/null
+++ b/src/joshua/decoder/ff/morph/InflectionPredictor.java
@@ -0,0 +1,98 @@
+package joshua.decoder.ff.morph;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.util.ArrayList;
+import java.util.List;
+
+import cc.mallet.classify.Classifier;
+import cc.mallet.classify.ClassifierTrainer;
+import cc.mallet.classify.MaxEntTrainer;
+import cc.mallet.classify.NaiveBayesTrainer;
+import cc.mallet.pipe.*;
+import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.pipe.iterator.FileIterator;
+import cc.mallet.pipe.iterator.LineIterator;
+import cc.mallet.types.InstanceList;
+import joshua.decoder.JoshuaConfiguration;
+import joshua.decoder.chart_parser.SourcePath;
+import joshua.decoder.ff.FeatureVector;
+import joshua.decoder.ff.StatelessFF;
+import joshua.decoder.ff.state_maintenance.DPState;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.hypergraph.HGNode;
+import joshua.decoder.segment_file.Sentence;
+
+public class InflectionPredictor extends StatelessFF {
+
+ private Classifier classifier = null;
+
+ public InflectionPredictor(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "InfectionPredictor", args, config);
+
+ if (parsedArgs.containsKey("model")) {
+ String modelFile = parsedArgs.get("model");
+ if (! new File(modelFile).exists()) {
+ if (parsedArgs.getOrDefault("training-data", null) != null) {
+ try {
+ classifier = train(parsedArgs.get("training-data"));
+ } catch (FileNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ } else {
+ System.err.println("* FATAL: no model and no training data.");
+ System.exit(1);
+ }
+ } else {
+ // TODO: load the model
+ }
+ }
+ }
+
+ public Classifier train(String dataFile) throws FileNotFoundException {
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert SVM-light format to sparse feature vector
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
+
+ // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
+ InstanceList instances = new InstanceList(new SerialPipes(pipeList));
+ instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
+ "(\\w+)\\s+(.*)",
+ 2, -1, 1));
+
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ Classifier classifier = trainer.train(instances);
+
+ return classifier;
+ }
+
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ return null;
+ }
+
+ public static void main(String[] args) throws FileNotFoundException {
+ InflectionPredictor ip = new InflectionPredictor(null, args, null);
+
+ String dataFile = "/Users/post/Desktop/amazon16/model";
+ if (args.length > 0)
+ dataFile = args[0];
+
+ ip.train(dataFile);
+ }
+
+}
[07/13] incubator-joshua git commit: added working example to ask for
help
Posted by mj...@apache.org.
added working example to ask for help
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/23306b4f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/23306b4f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/23306b4f
Branch: refs/heads/morph
Commit: 23306b4f020f1332817d9f89391a65d9debe8208
Parents: a055c3f
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 17:03:58 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 17:03:58 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 32 +++++++++++++++++++-
1 file changed, 31 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/23306b4f/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index fab71d9..a66f4a1 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -22,6 +22,7 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
+import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
@@ -290,7 +291,7 @@ public class LexicalSharpener extends StatelessFF {
return anchoredSource;
}
- public class CustomLineProcessor extends Pipe {
+ public static class CustomLineProcessor extends Pipe {
private static final long serialVersionUID = 1L;
@@ -338,6 +339,35 @@ public class LexicalSharpener extends StatelessFF {
}
}
+ public static void example(String[] args) throws IOException, ClassNotFoundException {
+
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ Alphabet dataAlphabet = new Alphabet();
+ LabelAlphabet labelAlphabet = new LabelAlphabet();
+
+ pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+ // Basically, SvmLight but with a custom (fixed) alphabet)
+ pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
+
+ FileReader reader1 = new FileReader("data.1");
+ FileReader reader2 = new FileReader("data.2");
+
+ SerialPipes pipes = new SerialPipes(pipeList);
+ InstanceList instances = new InstanceList(dataAlphabet, labelAlphabet);
+ instances.setPipe(pipes);
+ instances.addThruPipe(new CsvIterator(reader1, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
+ ClassifierTrainer trainer1 = new MaxEntTrainer();
+ Classifier classifier1 = trainer1.train(instances);
+
+ pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(dataAlphabet, labelAlphabet);
+ instances.setPipe(pipes);
+ instances.addThruPipe(new CsvIterator(reader2, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
+ ClassifierTrainer trainer2 = new MaxEntTrainer();
+ Classifier classifier2 = trainer2.train(instances);
+ }
+
public static void main(String[] args) throws IOException, ClassNotFoundException {
LexicalSharpener ts = new LexicalSharpener(null, args, null);
[06/13] incubator-joshua git commit: Will now train a single model,
barfs with Alphabet issues on second, ARGH
Posted by mj...@apache.org.
Will now train a single model, barfs with Alphabet issues on second, ARGH
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/a055c3f7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/a055c3f7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/a055c3f7
Branch: refs/heads/morph
Commit: a055c3f7edb84234ecfbdf56d4be92a60fc5fc0e
Parents: d17d08a
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 15:06:35 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 15:06:57 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 77 +++++++++-----------
1 file changed, 35 insertions(+), 42 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/a055c3f7/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index b4b455d..fab71d9 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -69,7 +69,7 @@ public class LexicalSharpener extends StatelessFF {
labelAlphabet = new LabelAlphabet();
// I don't know if this is needed
-// pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+ pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
// Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
// Validation
@@ -155,6 +155,12 @@ public class LexicalSharpener extends StatelessFF {
public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
classifiers = (HashMap<Integer,Object>) ois.readObject();
+ ois.close();
+
+ System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
+ for (int key: classifiers.keySet()) {
+ System.err.println(" " + key);
+ }
}
public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
@@ -172,14 +178,17 @@ public class LexicalSharpener extends StatelessFF {
* @return
*/
public Classification predict(String which, String features) {
- Instance instance = new Instance(features, which, null, null);
+ Instance instance = new Instance(features, which, which, which);
// System.err.println("PREDICT outcome = " + (String) instance.getTarget());
// System.err.println("PREDICT features = " + (String) instance.getData());
Classifier classifier = getClassifier(which);
- Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
-
- return result;
+ if (classifier != null) {
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+ return result;
+ }
+
+ return null;
}
public Classifier getClassifier(String which) {
@@ -190,20 +199,10 @@ public class LexicalSharpener extends StatelessFF {
StringReader reader = new StringReader((String)classifiers.get(id));
System.err.println("training new classifier from string for " + which);
+ System.err.println(String.format("training string is: '%s'", (String)classifiers.get(id)));
- BufferedReader t = new BufferedReader(reader);
- String line;
- try {
- while ((line = t.readLine()) != null) {
- System.out.println(" LINE: " + line);
- }
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
-
- // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
- instances.addThruPipe(new CsvIterator(reader, "(\\w+)\\s+(.*)", 2, -1, -1));
+ // Constructs an instance with everything shoved into the data field
+ instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
ClassifierTrainer trainer = new MaxEntTrainer();
Classifier classifier = trainer.train(instances);
@@ -295,10 +294,6 @@ public class LexicalSharpener extends StatelessFF {
private static final long serialVersionUID = 1L;
- public CustomLineProcessor() {
- super (new Alphabet(), new LabelAlphabet());
- }
-
public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
super(data, label);
}
@@ -311,12 +306,12 @@ public class LexicalSharpener extends StatelessFF {
String[] tokens = dataStr.split("\\s+");
- carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+// carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
// the rest are feature-value pairs
ArrayList<Integer> indices = new ArrayList<Integer>();
ArrayList<Double> values = new ArrayList<Double>();
- for (int termIndex = 1; termIndex < tokens.length; termIndex++) {
+ for (int termIndex = 0; termIndex < tokens.length; termIndex++) {
if (!tokens[termIndex].equals("")) {
String[] s = tokens[termIndex].split(":");
if (s.length != 2) {
@@ -324,13 +319,8 @@ public class LexicalSharpener extends StatelessFF {
}
String feature = s[0];
int index = getDataAlphabet().lookupIndex(feature, true);
-
- // index may be -1 if growth of the
- // data alphabet is stopped
- if (index >= 0) {
- indices.add(index);
- values.add(Double.parseDouble(s[1]));
- }
+ indices.add(index);
+ values.add(Double.parseDouble(s[1]));
}
}
@@ -359,24 +349,27 @@ public class LexicalSharpener extends StatelessFF {
System.err.println("Training model from file " + dataFile);
ts.train(dataFile);
- if (args.length > 1)
- modelFile = args[1];
-
- System.err.println("Writing model to file " + modelFile);
- ts.saveClassifiers(modelFile);
- } else {
- System.err.println("Loading model from file " + modelFile);
- ts.loadClassifiers(modelFile);
+// if (args.length > 1)
+// modelFile = args[1];
+//
+// System.err.println("Writing model to file " + modelFile);
+// ts.saveClassifiers(modelFile);
+// } else {
+// System.err.println("Loading model from file " + modelFile);
+// ts.loadClassifiers(modelFile);
}
Scanner stdin = new Scanner(System.in);
while(stdin.hasNextLine()) {
String line = stdin.nextLine();
String[] tokens = line.split(" ", 2);
- String outcome = tokens[0];
+ String sourceWord = tokens[0];
String features = tokens[1];
- Classification result = ts.predict(outcome, features);
- System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+ Classification result = ts.predict(sourceWord, features);
+ if (result != null)
+ System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+ else
+ System.out.println("i got nothing");
}
}
}
[11/13] incubator-joshua git commit: fix in flush location
Posted by mj...@apache.org.
fix in flush location
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/61cd4896
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/61cd4896
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/61cd4896
Branch: refs/heads/morph
Commit: 61cd4896d1b604312c234c37c356180df3c1c707
Parents: 4ec7ddb
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:12:26 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:12:26 2016 -0400
----------------------------------------------------------------------
src/joshua/util/io/LineReader.java | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/61cd4896/src/joshua/util/io/LineReader.java
----------------------------------------------------------------------
diff --git a/src/joshua/util/io/LineReader.java b/src/joshua/util/io/LineReader.java
index a4f9fe0..b4ef38c 100644
--- a/src/joshua/util/io/LineReader.java
+++ b/src/joshua/util/io/LineReader.java
@@ -274,7 +274,7 @@ public class LineReader implements Reader<String> {
// System.err.println(String.format("OLD %d NEW %d", progress, newProgress));
if (newProgress > progress) {
- for (int i = progress + 1; i <= newProgress; i++)
+ for (int i = progress + 1; i <= newProgress; i++) {
if (i == 97) {
System.err.print("1");
} else if (i == 98) {
@@ -285,13 +285,13 @@ public class LineReader implements Reader<String> {
System.err.println("%");
} else if (i % 10 == 0) {
System.err.print(String.format("%d", i));
- System.err.flush();
} else if ((i - 1) % 10 == 0)
; // skip at 11 since 10, 20, etc take two digits
else {
System.err.print(".");
- System.err.flush();
}
+ }
+ System.err.flush();
progress = newProgress;
}
}
[04/13] incubator-joshua git commit: renamed
Posted by mj...@apache.org.
renamed
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/0ef5d3eb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/0ef5d3eb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/0ef5d3eb
Branch: refs/heads/morph
Commit: 0ef5d3eb25f3abd13d60b69dcb290a65c8214c73
Parents: 47f1af5
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 09:15:42 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 09:15:42 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/InflectionPredictor.java | 246 -------------------
.../decoder/ff/morph/LexicalSharpener.java | 246 +++++++++++++++++++
2 files changed, 246 insertions(+), 246 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0ef5d3eb/src/joshua/decoder/ff/morph/InflectionPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/InflectionPredictor.java b/src/joshua/decoder/ff/morph/InflectionPredictor.java
deleted file mode 100644
index f4a4310..0000000
--- a/src/joshua/decoder/ff/morph/InflectionPredictor.java
+++ /dev/null
@@ -1,246 +0,0 @@
-package joshua.decoder.ff.morph;
-
-/***
- * This feature function scores a rule application by predicting, for each target word aligned with
- * a source word, how likely the lexical translation is in context.
- *
- * The feature function can be provided with a trained model or a raw training file which it will
- * then train prior to decoding.
- *
- * Format of training file:
- *
- * source_word target_word feature:value feature:value feature:value ...
- *
- * Invocation:
- *
- * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.LexicalSharpener /path/to/training/data
- */
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
-import java.io.FileReader;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Scanner;
-
-import cc.mallet.classify.*;
-import cc.mallet.pipe.*;
-import cc.mallet.pipe.iterator.CsvIterator;
-import cc.mallet.types.Instance;
-import cc.mallet.types.InstanceList;
-import joshua.corpus.Vocabulary;
-import joshua.decoder.JoshuaConfiguration;
-import joshua.decoder.chart_parser.SourcePath;
-import joshua.decoder.ff.FeatureVector;
-import joshua.decoder.ff.StatelessFF;
-import joshua.decoder.ff.state_maintenance.DPState;
-import joshua.decoder.ff.tm.Rule;
-import joshua.decoder.hypergraph.HGNode;
-import joshua.decoder.segment_file.Sentence;
-import joshua.decoder.segment_file.Token;
-
-public class InflectionPredictor extends StatelessFF {
-
- private Classifier classifier = null;
- private SerialPipes pipes = null;
-
- public InflectionPredictor(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
- super(weights, "LexicalSharpener", args, config);
-
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- // I don't know if this is needed
- pipeList.add(new Target2Label());
- // Convert SVM-light format to sparse feature vector
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
- // Validation
-// pipeList.add(new PrintInputAndTarget());
-
- // name: english word
- // data: features (FeatureVector)
- // target: foreign inflection
- // source: null
-
- pipes = new SerialPipes(pipeList);
-
- if (parsedArgs.containsKey("model")) {
- String modelFile = parsedArgs.get("model");
- if (! new File(modelFile).exists()) {
- if (parsedArgs.getOrDefault("training-data", null) != null) {
- try {
- classifier = train(parsedArgs.get("training-data"));
- } catch (FileNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- } else {
- System.err.println("* FATAL: no model and no training data.");
- System.exit(1);
- }
- } else {
- try {
- loadClassifier(modelFile);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (ClassNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
- }
- }
-
- /**
- * Trains a maxent classifier from the provided training data, returning a Mallet model.
- *
- * @param dataFile
- * @return
- * @throws FileNotFoundException
- */
- public Classifier train(String dataFile) throws FileNotFoundException {
-
- // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
- InstanceList instances = new InstanceList(pipes);
- instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
- "(\\w+)\\s+(.*)",
- 2, -1, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- Classifier classifier = trainer.train(instances);
-
- return classifier;
- }
-
- public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
- ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifier = (Classifier) ois.readObject();
- }
-
- public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
- ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
- oos.writeObject(classifier);
- oos.close();
- }
-
- public Classification predict(String outcome, String features) {
- Instance instance = new Instance(features, null, null, null);
- System.err.println("PREDICT outcome = " + (String) instance.getTarget());
- System.err.println("PREDICT features = " + (String) instance.getData());
- Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
-
- return result;
- }
-
- /**
- * Compute features. This works by walking over the target side phrase pieces, looking for every
- * word with a single source-aligned word. We then throw the annotations from that source word
- * into our prediction model to learn how much it likes the chosen word. Presumably the source-
- * language annotations have contextual features, so this effectively chooses the words in context.
- */
- @Override
- public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
- Sentence sentence, Accumulator acc) {
-
- Map<Integer, List<Integer>> points = rule.getAlignmentMap();
- for (int t: points.keySet()) {
- List<Integer> source_indices = points.get(t);
- if (source_indices.size() != 1)
- continue;
-
- String targetWord = Vocabulary.word(rule.getEnglish()[t]);
- int s = i + source_indices.get(0);
- Token sourceToken = sentence.getTokens().get(s);
- String featureString = sourceToken.getAnnotationString().replace('|', ' ');
-
- Classification result = predict(targetWord, featureString);
- if (result.bestLabelIsCorrect()) {
- acc.add(String.format("%s_match", name), 1);
- }
- }
-
- return null;
- }
-
- /**
- * Returns an array parallel to the source words array indicating, for each index, the absolute
- * position of that word into the source sentence. For example, for the rule with source side
- *
- * [ 17, 142, -14, 9 ]
- *
- * and source sentence
- *
- * [ 17, 18, 142, 1, 1, 9, 8 ]
- *
- * it will return
- *
- * [ 0, 2, -14, 5 ]
- *
- * which indicates that the first, second, and fourth words of the rule are anchored to the
- * first, third, and sixth words of the input sentence.
- *
- * @param rule
- * @param tailNodes
- * @param start
- * @return a list of alignment points anchored to the source sentence
- */
- public int[] anchorRuleSourceToSentence(Rule rule, List<HGNode> tailNodes, int start) {
- int[] source = rule.getFrench();
-
- // Map the source words in the rule to absolute positions in the sentence
- int[] anchoredSource = source.clone();
-
- int sourceIndex = start;
- int tailNodeIndex = 0;
- for (int i = 0; i < source.length; i++) {
- if (source[i] < 0) { // nonterminal
- anchoredSource[i] = source[i];
- sourceIndex = tailNodes.get(tailNodeIndex).j;
- tailNodeIndex++;
- } else { // terminal
- anchoredSource[i] = sourceIndex;
- sourceIndex++;
- }
- }
-
- return anchoredSource;
- }
-
- public static void main(String[] args) throws IOException, ClassNotFoundException {
- InflectionPredictor ts = new InflectionPredictor(null, args, null);
-
- String modelFile = "model";
-
- if (args.length > 0) {
- String dataFile = args[0];
-
- System.err.println("Training model from file " + dataFile);
- ts.train(dataFile);
-
- if (args.length > 1)
- modelFile = args[1];
-
- System.err.println("Writing model to file " + modelFile);
- ts.saveClassifier(modelFile);
- } else {
- System.err.println("Loading model from file " + modelFile);
- ts.loadClassifier(modelFile);
- }
-
- Scanner stdin = new Scanner(System.in);
- while(stdin.hasNextLine()) {
- String line = stdin.nextLine();
- String[] tokens = line.split(" ", 2);
- String outcome = tokens[0];
- String features = tokens[1];
- Classification result = ts.predict(outcome, features);
- System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/0ef5d3eb/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
new file mode 100644
index 0000000..edf4390
--- /dev/null
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -0,0 +1,246 @@
+package joshua.decoder.ff.morph;
+
+/***
+ * This feature function scores a rule application by predicting, for each target word aligned with
+ * a source word, how likely the lexical translation is in context.
+ *
+ * The feature function can be provided with a trained model or a raw training file which it will
+ * then train prior to decoding.
+ *
+ * Format of training file:
+ *
+ * source_word target_word feature:value feature:value feature:value ...
+ *
+ * Invocation:
+ *
+ * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.LexicalSharpener /path/to/training/data
+ */
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
+
+import cc.mallet.classify.*;
+import cc.mallet.pipe.*;
+import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Instance;
+import cc.mallet.types.InstanceList;
+import joshua.corpus.Vocabulary;
+import joshua.decoder.JoshuaConfiguration;
+import joshua.decoder.chart_parser.SourcePath;
+import joshua.decoder.ff.FeatureVector;
+import joshua.decoder.ff.StatelessFF;
+import joshua.decoder.ff.state_maintenance.DPState;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.hypergraph.HGNode;
+import joshua.decoder.segment_file.Sentence;
+import joshua.decoder.segment_file.Token;
+
+public class LexicalSharpener extends StatelessFF {
+
+ private Classifier classifier = null;
+ private SerialPipes pipes = null;
+
+ public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "LexicalSharpener", args, config);
+
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert SVM-light format to sparse feature vector
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
+
+ pipes = new SerialPipes(pipeList);
+
+ if (parsedArgs.containsKey("model")) {
+ String modelFile = parsedArgs.get("model");
+ if (! new File(modelFile).exists()) {
+ if (parsedArgs.getOrDefault("training-data", null) != null) {
+ try {
+ classifier = train(parsedArgs.get("training-data"));
+ } catch (FileNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ } else {
+ System.err.println("* FATAL: no model and no training data.");
+ System.exit(1);
+ }
+ } else {
+ try {
+ loadClassifier(modelFile);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (ClassNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+ }
+ }
+
+ /**
+ * Trains a maxent classifier from the provided training data, returning a Mallet model.
+ *
+ * @param dataFile
+ * @return
+ * @throws FileNotFoundException
+ */
+ public Classifier train(String dataFile) throws FileNotFoundException {
+
+ // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
+ InstanceList instances = new InstanceList(pipes);
+ instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
+ "(\\w+)\\s+(.*)",
+ 2, -1, 1));
+
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ Classifier classifier = trainer.train(instances);
+
+ return classifier;
+ }
+
+ public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
+ ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
+ classifier = (Classifier) ois.readObject();
+ }
+
+ public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
+ ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
+ oos.writeObject(classifier);
+ oos.close();
+ }
+
+ public Classification predict(String outcome, String features) {
+ Instance instance = new Instance(features, null, null, null);
+ System.err.println("PREDICT outcome = " + (String) instance.getTarget());
+ System.err.println("PREDICT features = " + (String) instance.getData());
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+
+ return result;
+ }
+
+ /**
+ * Compute features. This works by walking over the target side phrase pieces, looking for every
+ * word with a single source-aligned word. We then throw the annotations from that source word
+ * into our prediction model to learn how much it likes the chosen word. Presumably the source-
+ * language annotations have contextual features, so this effectively chooses the words in context.
+ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ Map<Integer, List<Integer>> points = rule.getAlignmentMap();
+ for (int t: points.keySet()) {
+ List<Integer> source_indices = points.get(t);
+ if (source_indices.size() != 1)
+ continue;
+
+ String targetWord = Vocabulary.word(rule.getEnglish()[t]);
+ int s = i + source_indices.get(0);
+ Token sourceToken = sentence.getTokens().get(s);
+ String featureString = sourceToken.getAnnotationString().replace('|', ' ');
+
+ Classification result = predict(targetWord, featureString);
+ if (result.bestLabelIsCorrect()) {
+ acc.add(String.format("%s_match", name), 1);
+ }
+ }
+
+ return null;
+ }
+
+ /**
+ * Returns an array parallel to the source words array indicating, for each index, the absolute
+ * position of that word into the source sentence. For example, for the rule with source side
+ *
+ * [ 17, 142, -14, 9 ]
+ *
+ * and source sentence
+ *
+ * [ 17, 18, 142, 1, 1, 9, 8 ]
+ *
+ * it will return
+ *
+ * [ 0, 2, -14, 5 ]
+ *
+ * which indicates that the first, second, and fourth words of the rule are anchored to the
+ * first, third, and sixth words of the input sentence.
+ *
+ * @param rule
+ * @param tailNodes
+ * @param start
+ * @return a list of alignment points anchored to the source sentence
+ */
+ public int[] anchorRuleSourceToSentence(Rule rule, List<HGNode> tailNodes, int start) {
+ int[] source = rule.getFrench();
+
+ // Map the source words in the rule to absolute positions in the sentence
+ int[] anchoredSource = source.clone();
+
+ int sourceIndex = start;
+ int tailNodeIndex = 0;
+ for (int i = 0; i < source.length; i++) {
+ if (source[i] < 0) { // nonterminal
+ anchoredSource[i] = source[i];
+ sourceIndex = tailNodes.get(tailNodeIndex).j;
+ tailNodeIndex++;
+ } else { // terminal
+ anchoredSource[i] = sourceIndex;
+ sourceIndex++;
+ }
+ }
+
+ return anchoredSource;
+ }
+
+ public static void main(String[] args) throws IOException, ClassNotFoundException {
+ LexicalSharpener ts = new LexicalSharpener(null, args, null);
+
+ String modelFile = "model";
+
+ if (args.length > 0) {
+ String dataFile = args[0];
+
+ System.err.println("Training model from file " + dataFile);
+ ts.train(dataFile);
+
+ if (args.length > 1)
+ modelFile = args[1];
+
+ System.err.println("Writing model to file " + modelFile);
+ ts.saveClassifier(modelFile);
+ } else {
+ System.err.println("Loading model from file " + modelFile);
+ ts.loadClassifier(modelFile);
+ }
+
+ Scanner stdin = new Scanner(System.in);
+ while(stdin.hasNextLine()) {
+ String line = stdin.nextLine();
+ String[] tokens = line.split(" ", 2);
+ String outcome = tokens[0];
+ String features = tokens[1];
+ Classification result = ts.predict(outcome, features);
+ System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+ }
+ }
+}
[05/13] incubator-joshua git commit: Version with custom line
processor; still not working (alphabet problems)
Posted by mj...@apache.org.
Version with custom line processor; still not working (alphabet problems)
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/d17d08af
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/d17d08af
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/d17d08af
Branch: refs/heads/morph
Commit: d17d08afb449ce86c5f0d9ca5b2ccdc5ea777fda
Parents: 0ef5d3e
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 14:11:37 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 14:11:37 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 192 ++++++++++++++++---
1 file changed, 164 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d17d08af/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index edf4390..b4b455d 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -1,5 +1,7 @@
package joshua.decoder.ff.morph;
+import java.io.BufferedReader;
+
/***
* This feature function scores a rule application by predicting, for each target word aligned with
* a source word, how likely the lexical translation is in context.
@@ -20,11 +22,12 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
-import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.io.StringReader;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
@@ -32,9 +35,12 @@ import java.util.Scanner;
import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
+import cc.mallet.types.LabelAlphabet;
import joshua.corpus.Vocabulary;
+import joshua.decoder.Decoder;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.FeatureVector;
@@ -44,21 +50,28 @@ import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.segment_file.Sentence;
import joshua.decoder.segment_file.Token;
+import joshua.util.io.LineReader;
public class LexicalSharpener extends StatelessFF {
- private Classifier classifier = null;
+ private HashMap<Integer,Object> classifiers = null;
private SerialPipes pipes = null;
+ private InstanceList instances = null;
+ private LabelAlphabet labelAlphabet = null;
+ private Alphabet dataAlphabet = null;
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+ dataAlphabet = new Alphabet();
+ labelAlphabet = new LabelAlphabet();
+
// I don't know if this is needed
- pipeList.add(new Target2Label());
- // Convert SVM-light format to sparse feature vector
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
+// pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+ // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+ pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
// Validation
// pipeList.add(new PrintInputAndTarget());
@@ -68,13 +81,15 @@ public class LexicalSharpener extends StatelessFF {
// source: null
pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(dataAlphabet, labelAlphabet);
+ instances.setPipe(pipes);
if (parsedArgs.containsKey("model")) {
String modelFile = parsedArgs.get("model");
if (! new File(modelFile).exists()) {
if (parsedArgs.getOrDefault("training-data", null) != null) {
try {
- classifier = train(parsedArgs.get("training-data"));
+ train(parsedArgs.get("training-data"));
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
@@ -85,7 +100,7 @@ public class LexicalSharpener extends StatelessFF {
}
} else {
try {
- loadClassifier(modelFile);
+ loadClassifiers(modelFile);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
@@ -104,39 +119,103 @@ public class LexicalSharpener extends StatelessFF {
* @return
* @throws FileNotFoundException
*/
- public Classifier train(String dataFile) throws FileNotFoundException {
-
- // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
- InstanceList instances = new InstanceList(pipes);
- instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
- "(\\w+)\\s+(.*)",
- 2, -1, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- Classifier classifier = trainer.train(instances);
+ public void train(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Object>();
- return classifier;
+ Decoder.VERBOSE = 1;
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastOutcome = null;
+ String buffer = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String outcome = line.substring(0, line.indexOf(' '));
+ if (lastOutcome != null && ! outcome.equals(lastOutcome)) {
+ classifiers.put(Vocabulary.id(lastOutcome), buffer);
+// System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ buffer = "";
+ }
+
+ buffer += line + "\n";
+ lastOutcome = outcome;
+ linesRead++;
+ }
+ classifiers.put(Vocabulary.id(lastOutcome), buffer);
+
+ System.err.println(String.format("Read %d lines from training file", linesRead));
}
- public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
+ public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifier = (Classifier) ois.readObject();
+ classifiers = (HashMap<Integer,Object>) ois.readObject();
}
- public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
+ public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
- oos.writeObject(classifier);
+ oos.writeObject(classifiers);
oos.close();
}
- public Classification predict(String outcome, String features) {
- Instance instance = new Instance(features, null, null, null);
- System.err.println("PREDICT outcome = " + (String) instance.getTarget());
- System.err.println("PREDICT features = " + (String) instance.getData());
+ /**
+ * Returns a Classification object for a list of features. Uses "which" to determine which
+ * classifier to use.
+ *
+ * @param which the classifier to use
+ * @param features the set of features
+ * @return
+ */
+ public Classification predict(String which, String features) {
+ Instance instance = new Instance(features, which, null, null);
+// System.err.println("PREDICT outcome = " + (String) instance.getTarget());
+// System.err.println("PREDICT features = " + (String) instance.getData());
+
+ Classifier classifier = getClassifier(which);
Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
return result;
}
+
+ public Classifier getClassifier(String which) {
+ int id = Vocabulary.id(which);
+
+ if (classifiers.containsKey(id)) {
+ if (classifiers.get(id) instanceof String) {
+ StringReader reader = new StringReader((String)classifiers.get(id));
+
+ System.err.println("training new classifier from string for " + which);
+
+ BufferedReader t = new BufferedReader(reader);
+ String line;
+ try {
+ while ((line = t.readLine()) != null) {
+ System.out.println(" LINE: " + line);
+ }
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
+ instances.addThruPipe(new CsvIterator(reader, "(\\w+)\\s+(.*)", 2, -1, -1));
+
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ Classifier classifier = trainer.train(instances);
+
+ classifiers.put(id, classifier);
+ }
+
+ return (Classifier) classifiers.get(id);
+ }
+
+ return null;
+ }
/**
* Compute features. This works by walking over the target side phrase pieces, looking for every
@@ -211,6 +290,63 @@ public class LexicalSharpener extends StatelessFF {
return anchoredSource;
}
+
+ public class CustomLineProcessor extends Pipe {
+
+ private static final long serialVersionUID = 1L;
+
+ public CustomLineProcessor() {
+ super (new Alphabet(), new LabelAlphabet());
+ }
+
+ public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
+ super(data, label);
+ }
+
+ @Override
+ public Instance pipe(Instance carrier) {
+ // we expect the data for each instance to be
+ // a line from the SVMLight format text file
+ String dataStr = (String)carrier.getData();
+
+ String[] tokens = dataStr.split("\\s+");
+
+ carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+
+ // the rest are feature-value pairs
+ ArrayList<Integer> indices = new ArrayList<Integer>();
+ ArrayList<Double> values = new ArrayList<Double>();
+ for (int termIndex = 1; termIndex < tokens.length; termIndex++) {
+ if (!tokens[termIndex].equals("")) {
+ String[] s = tokens[termIndex].split(":");
+ if (s.length != 2) {
+ throw new RuntimeException("invalid format: " + tokens[termIndex] + " (should be feature:value)");
+ }
+ String feature = s[0];
+ int index = getDataAlphabet().lookupIndex(feature, true);
+
+ // index may be -1 if growth of the
+ // data alphabet is stopped
+ if (index >= 0) {
+ indices.add(index);
+ values.add(Double.parseDouble(s[1]));
+ }
+ }
+ }
+
+ assert(indices.size() == values.size());
+ int[] indicesArr = new int[indices.size()];
+ double[] valuesArr = new double[values.size()];
+ for (int i = 0; i < indicesArr.length; i++) {
+ indicesArr[i] = indices.get(i);
+ valuesArr[i] = values.get(i);
+ }
+
+ cc.mallet.types.FeatureVector fv = new cc.mallet.types.FeatureVector(getDataAlphabet(), indicesArr, valuesArr);
+ carrier.setData(fv);
+ return carrier;
+ }
+ }
public static void main(String[] args) throws IOException, ClassNotFoundException {
LexicalSharpener ts = new LexicalSharpener(null, args, null);
@@ -227,10 +363,10 @@ public class LexicalSharpener extends StatelessFF {
modelFile = args[1];
System.err.println("Writing model to file " + modelFile);
- ts.saveClassifier(modelFile);
+ ts.saveClassifiers(modelFile);
} else {
System.err.println("Loading model from file " + modelFile);
- ts.loadClassifier(modelFile);
+ ts.loadClassifiers(modelFile);
}
Scanner stdin = new Scanner(System.in);
[13/13] incubator-joshua git commit: File move
Posted by mj...@apache.org.
File move
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/a86ae8e8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/a86ae8e8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/a86ae8e8
Branch: refs/heads/morph
Commit: a86ae8e87d9c6fcaae046aaebbf64634ba4bf5fc
Parents: 95ddf09
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:13:29 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:13:29 2016 -0400
----------------------------------------------------------------------
src/joshua/decoder/ff/LexicalSharpener.java | 356 ++++++++++++++++++
.../decoder/ff/morph/LexicalSharpener.java | 357 -------------------
2 files changed, 356 insertions(+), 357 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/a86ae8e8/src/joshua/decoder/ff/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/LexicalSharpener.java b/src/joshua/decoder/ff/LexicalSharpener.java
new file mode 100644
index 0000000..2c96f83
--- /dev/null
+++ b/src/joshua/decoder/ff/LexicalSharpener.java
@@ -0,0 +1,356 @@
+package joshua.decoder.ff;
+
+/***
+ * This feature function scores a rule application by predicting, for each target word aligned with
+ * a source word, how likely the lexical translation is in context.
+ *
+ * The feature function can be provided with a trained model or a raw training file which it will
+ * then train prior to decoding.
+ *
+ * Format of training file:
+ *
+ * source_word target_word feature:value feature:value feature:value ...
+ *
+ * Invocation:
+ *
+ * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.LexicalSharpener /path/to/training/data
+ */
+
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.StringReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
+
+import cc.mallet.classify.*;
+import cc.mallet.pipe.*;
+import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Alphabet;
+import cc.mallet.types.Instance;
+import cc.mallet.types.InstanceList;
+import cc.mallet.types.LabelAlphabet;
+import joshua.corpus.Vocabulary;
+import joshua.decoder.Decoder;
+import joshua.decoder.JoshuaConfiguration;
+import joshua.decoder.chart_parser.SourcePath;
+import joshua.decoder.ff.FeatureVector;
+import joshua.decoder.ff.StatelessFF;
+import joshua.decoder.ff.state_maintenance.DPState;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.hypergraph.HGNode;
+import joshua.decoder.segment_file.Sentence;
+import joshua.decoder.segment_file.Token;
+import joshua.util.io.LineReader;
+
+public class LexicalSharpener extends StatelessFF {
+
+ private HashMap<Integer,Predictor> classifiers = null;
+ public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
+ super(weights, "LexicalSharpener", args, config);
+
+ if (parsedArgs.getOrDefault("training-data", null) != null) {
+ try {
+ trainAll(parsedArgs.get("training-data"));
+ } catch (FileNotFoundException e) {
+ System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
+ System.exit(1);
+ }
+ }
+ }
+
+ /**
+ * Reads the training data and stores one example buffer (wrapped in a Predictor) per source
+ * word; each predictor's maxent classifier is trained lazily on first prediction.
+ *
+ * @param dataFile path to the training data file, grouped by source word
+ * @throws FileNotFoundException if the training data file cannot be opened
+ */
+ public void trainAll(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Predictor>();
+
+ Decoder.LOG(1, "Reading " + dataFile);
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastSourceWord = null;
+ String examples = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String sourceWord = line.substring(0, line.indexOf(' '));
+ if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ examples = "";
+ }
+
+ examples += line + "\n";
+ lastSourceWord = sourceWord;
+ linesRead++;
+ }
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+
+ System.err.println(String.format("Read %d lines from training file", linesRead));
+ }
+
+ public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
+ ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
+ classifiers = (HashMap<Integer,Predictor>) ois.readObject();
+ ois.close();
+
+ System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
+ for (int key: classifiers.keySet()) {
+ System.err.println(" " + key);
+ }
+ }
+
+ public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
+ ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
+ oos.writeObject(classifiers);
+ oos.close();
+ }
+
+ /**
+ * Compute features. This works by walking over the target side phrase pieces, looking for every
+ * word with a single source-aligned word. We then throw the annotations from that source word
+ * into our prediction model to learn how much it likes the chosen word. Presumably the source-
+ * language annotations have contextual features, so this effectively chooses the words in context.
+ */
+ @Override
+ public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
+ Sentence sentence, Accumulator acc) {
+
+ System.err.println(String.format("RULE: %s", rule));
+
+ Map<Integer, List<Integer>> points = rule.getAlignmentMap();
+ for (int t: points.keySet()) {
+ List<Integer> source_indices = points.get(t);
+ if (source_indices.size() != 1)
+ continue;
+
+ int targetID = rule.getEnglish()[t];
+ String targetWord = Vocabulary.word(targetID);
+ int s = i + source_indices.get(0);
+ Token sourceToken = sentence.getTokens().get(s);
+ String featureString = sourceToken.getAnnotationString().replace('|', ' ');
+
+ Classification result = predict(sourceToken.getWord(), targetID, featureString);
+ System.out.println("RESULT: " + result.getLabeling());
+ if (result.bestLabelIsCorrect()) {
+ acc.add(String.format("%s_match", name), 1);
+ }
+ }
+
+ return null;
+ }
+
+ public Classification predict(int sourceID, int targetID, String featureString) {
+ String word = Vocabulary.word(sourceID);
+ if (classifiers.containsKey(sourceID)) {
+ Predictor predictor = classifiers.get(sourceID);
+ if (predictor != null)
+ return predictor.predict(Vocabulary.word(targetID), featureString);
+ }
+
+ return null;
+ }
+
+ /**
+ * Returns an array parallel to the source words array indicating, for each index, the absolute
+ * position of that word into the source sentence. For example, for the rule with source side
+ *
+ * [ 17, 142, -14, 9 ]
+ *
+ * and source sentence
+ *
+ * [ 17, 18, 142, 1, 1, 9, 8 ]
+ *
+ * it will return
+ *
+ * [ 0, 2, -14, 5 ]
+ *
+ * which indicates that the first, second, and fourth words of the rule are anchored to the
+ * first, third, and sixth words of the input sentence.
+ *
+ * @param rule
+ * @param tailNodes
+ * @param start
+ * @return a list of alignment points anchored to the source sentence
+ */
+ public int[] anchorRuleSourceToSentence(Rule rule, List<HGNode> tailNodes, int start) {
+ int[] source = rule.getFrench();
+
+ // Map the source words in the rule to absolute positions in the sentence
+ int[] anchoredSource = source.clone();
+
+ int sourceIndex = start;
+ int tailNodeIndex = 0;
+ for (int i = 0; i < source.length; i++) {
+ if (source[i] < 0) { // nonterminal
+ anchoredSource[i] = source[i];
+ sourceIndex = tailNodes.get(tailNodeIndex).j;
+ tailNodeIndex++;
+ } else { // terminal
+ anchoredSource[i] = sourceIndex;
+ sourceIndex++;
+ }
+ }
+
+ return anchoredSource;
+ }
+
+ public class Predictor {
+
+ private SerialPipes pipes = null;
+ private InstanceList instances = null;
+ private String sourceWord = null;
+ private String examples = null;
+ private Classifier classifier = null;
+
+ public Predictor(String word, String examples) {
+ this.sourceWord = word;
+ this.examples = examples;
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
+
+ pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(pipes);
+ }
+
+ /**
+ * Classifies a set of features against this predictor's model, training the model first
+ * if it has not been built yet.
+ *
+ * @param outcome the candidate target-word label for the instance
+ * @param features the space-delimited set of feature:value pairs
+ * @return the Classification result from the underlying maxent classifier
+ */
+ public Classification predict(String outcome, String features) {
+ Instance instance = new Instance(features, outcome, null, null);
+ System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
+ System.err.println("PREDICT features = " + (String) instance.getData());
+
+ if (classifier == null)
+ train();
+
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+ return result;
+ }
+
+ public void train() {
+// System.err.println(String.format("Word %s: training model", sourceWord));
+// System.err.println(String.format(" Examples: %s", examples));
+
+ StringReader reader = new StringReader(examples);
+
+ // Constructs an instance with everything shoved into the data field
+ instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
+
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ classifier = trainer.train(instances);
+
+ System.err.println(String.format("Trained a model for %s with %d outcomes",
+ sourceWord, pipes.getTargetAlphabet().size()));
+ }
+
+ /**
+ * Returns the number of distinct outcomes. Requires the model to have been trained!
+ *
+ * @return
+ */
+ public int getNumOutcomes() {
+ if (classifier == null)
+ train();
+ return pipes.getTargetAlphabet().size();
+ }
+ }
+
+ public static void example(String[] args) throws IOException, ClassNotFoundException {
+
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ Alphabet dataAlphabet = new Alphabet();
+ LabelAlphabet labelAlphabet = new LabelAlphabet();
+
+ pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+ // Basically, SvmLight but with a custom (fixed) alphabet
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+
+ FileReader reader1 = new FileReader("data.1");
+ FileReader reader2 = new FileReader("data.2");
+
+ SerialPipes pipes = new SerialPipes(pipeList);
+ InstanceList instances = new InstanceList(dataAlphabet, labelAlphabet);
+ instances.setPipe(pipes);
+ instances.addThruPipe(new CsvIterator(reader1, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
+ ClassifierTrainer trainer1 = new MaxEntTrainer();
+ Classifier classifier1 = trainer1.train(instances);
+
+ pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(dataAlphabet, labelAlphabet);
+ instances.setPipe(pipes);
+ instances.addThruPipe(new CsvIterator(reader2, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
+ ClassifierTrainer trainer2 = new MaxEntTrainer();
+ Classifier classifier2 = trainer2.train(instances);
+ }
+
+ public static void main(String[] args) throws IOException, ClassNotFoundException {
+ LexicalSharpener ts = new LexicalSharpener(null, args, null);
+
+ String modelFile = "model";
+
+ if (args.length > 0) {
+ String dataFile = args[0];
+
+ System.err.println("Training model from file " + dataFile);
+ ts.trainAll(dataFile);
+
+// if (args.length > 1)
+// modelFile = args[1];
+//
+// System.err.println("Writing model to file " + modelFile);
+// ts.saveClassifiers(modelFile);
+// } else {
+// System.err.println("Loading model from file " + modelFile);
+// ts.loadClassifiers(modelFile);
+ }
+
+ Scanner stdin = new Scanner(System.in);
+ while(stdin.hasNextLine()) {
+ String line = stdin.nextLine();
+ String[] tokens = line.split(" ", 3);
+ String sourceWord = tokens[0];
+ String targetWord = tokens[1];
+ String features = tokens[2];
+ Classification result = ts.predict(Vocabulary.id(sourceWord), Vocabulary.id(targetWord), features);
+ if (result != null)
+ System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+ else
+ System.out.println("i got nothing");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/a86ae8e8/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
deleted file mode 100644
index 7982db4..0000000
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ /dev/null
@@ -1,357 +0,0 @@
-package joshua.decoder.ff.morph;
-
-/***
- * This feature function scores a rule application by predicting, for each target word aligned with
- * a source word, how likely the lexical translation is in context.
- *
- * The feature function can be provided with a trained model or a raw training file which it will
- * then train prior to decoding.
- *
- * Format of training file:
- *
- * source_word target_word feature:value feature:value feature:value ...
- *
- * Invocation:
- *
- * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.LexicalSharpener /path/to/training/data
- */
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
-import java.io.FileReader;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.StringReader;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Scanner;
-
-import cc.mallet.classify.*;
-import cc.mallet.pipe.*;
-import cc.mallet.pipe.iterator.CsvIterator;
-import cc.mallet.types.Alphabet;
-import cc.mallet.types.Instance;
-import cc.mallet.types.InstanceList;
-import cc.mallet.types.LabelAlphabet;
-import joshua.corpus.Vocabulary;
-import joshua.decoder.Decoder;
-import joshua.decoder.JoshuaConfiguration;
-import joshua.decoder.chart_parser.SourcePath;
-import joshua.decoder.ff.FeatureVector;
-import joshua.decoder.ff.StatelessFF;
-import joshua.decoder.ff.state_maintenance.DPState;
-import joshua.decoder.ff.tm.Rule;
-import joshua.decoder.hypergraph.HGNode;
-import joshua.decoder.segment_file.Sentence;
-import joshua.decoder.segment_file.Token;
-import joshua.util.io.LineReader;
-
-public class LexicalSharpener extends StatelessFF {
-
- private HashMap<Integer,Predictor> classifiers = null;
- public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
- super(weights, "LexicalSharpener", args, config);
-
- if (parsedArgs.getOrDefault("training-data", null) != null) {
- try {
- trainAll(parsedArgs.get("training-data"));
- } catch (FileNotFoundException e) {
- System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
- System.exit(1);
- }
- }
- }
-
- /**
- * Trains a maxent classifier from the provided training data, returning a Mallet model.
- *
- * @param dataFile
- * @return
- * @throws FileNotFoundException
- */
- public void trainAll(String dataFile) throws FileNotFoundException {
-
- classifiers = new HashMap<Integer, Predictor>();
-
- Decoder.LOG(1, "Reading " + dataFile);
- LineReader lineReader = null;
- try {
- lineReader = new LineReader(dataFile, true);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
-
- String lastSourceWord = null;
- String examples = "";
- int linesRead = 0;
- for (String line : lineReader) {
- String sourceWord = line.substring(0, line.indexOf(' '));
- if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
- // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
- examples = "";
- }
-
- examples += line + "\n";
- lastSourceWord = sourceWord;
- linesRead++;
- }
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
-
- System.err.println(String.format("Read %d lines from training file", linesRead));
- }
-
- public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
- ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifiers = (HashMap<Integer,Predictor>) ois.readObject();
- ois.close();
-
- System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
- for (int key: classifiers.keySet()) {
- System.err.println(" " + key);
- }
- }
-
- public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
- ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
- oos.writeObject(classifiers);
- oos.close();
- }
-
- /**
- * Compute features. This works by walking over the target side phrase pieces, looking for every
- * word with a single source-aligned word. We then throw the annotations from that source word
- * into our prediction model to learn how much it likes the chosen word. Presumably the source-
- * language annotations have contextual features, so this effectively chooses the words in context.
- */
- @Override
- public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
- Sentence sentence, Accumulator acc) {
-
- System.err.println(String.format("RULE: %s", rule));
-
- Map<Integer, List<Integer>> points = rule.getAlignmentMap();
- for (int t: points.keySet()) {
- List<Integer> source_indices = points.get(t);
- if (source_indices.size() != 1)
- continue;
-
- int targetID = rule.getEnglish()[t];
- String targetWord = Vocabulary.word(targetID);
- int s = i + source_indices.get(0);
- Token sourceToken = sentence.getTokens().get(s);
- String featureString = sourceToken.getAnnotationString().replace('|', ' ');
-
- Classification result = predict(sourceToken.getWord(), targetID, featureString);
- System.out.println("RESULT: " + result.getLabeling());
- if (result.bestLabelIsCorrect()) {
- acc.add(String.format("%s_match", name), 1);
- }
- }
-
- return null;
- }
-
- public Classification predict(int sourceID, int targetID, String featureString) {
- String word = Vocabulary.word(sourceID);
- if (classifiers.containsKey(sourceID)) {
- Predictor predictor = classifiers.get(sourceID);
- if (predictor != null)
- return predictor.predict(Vocabulary.word(targetID), featureString);
- }
-
- return null;
- }
-
- /**
- * Returns an array parallel to the source words array indicating, for each index, the absolute
- * position of that word into the source sentence. For example, for the rule with source side
- *
- * [ 17, 142, -14, 9 ]
- *
- * and source sentence
- *
- * [ 17, 18, 142, 1, 1, 9, 8 ]
- *
- * it will return
- *
- * [ 0, 2, -14, 5 ]
- *
- * which indicates that the first, second, and fourth words of the rule are anchored to the
- * first, third, and sixth words of the input sentence.
- *
- * @param rule
- * @param tailNodes
- * @param start
- * @return a list of alignment points anchored to the source sentence
- */
- public int[] anchorRuleSourceToSentence(Rule rule, List<HGNode> tailNodes, int start) {
- int[] source = rule.getFrench();
-
- // Map the source words in the rule to absolute positions in the sentence
- int[] anchoredSource = source.clone();
-
- int sourceIndex = start;
- int tailNodeIndex = 0;
- for (int i = 0; i < source.length; i++) {
- if (source[i] < 0) { // nonterminal
- anchoredSource[i] = source[i];
- sourceIndex = tailNodes.get(tailNodeIndex).j;
- tailNodeIndex++;
- } else { // terminal
- anchoredSource[i] = sourceIndex;
- sourceIndex++;
- }
- }
-
- return anchoredSource;
- }
-
- public class Predictor {
-
- private SerialPipes pipes = null;
- private InstanceList instances = null;
- private String sourceWord = null;
- private String examples = null;
- private Classifier classifier = null;
-
- public Predictor(String word, String examples) {
- this.sourceWord = word;
- this.examples = examples;
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- // I don't know if this is needed
- pipeList.add(new Target2Label());
- // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
- // Validation
-// pipeList.add(new PrintInputAndTarget());
-
- // name: english word
- // data: features (FeatureVector)
- // target: foreign inflection
- // source: null
-
- pipes = new SerialPipes(pipeList);
- instances = new InstanceList(pipes);
- }
-
- /**
- * Returns a Classification object a list of features. Uses "which" to determine which classifier
- * to use.
- *
- * @param which the classifier to use
- * @param features the set of features
- * @return
- */
- public Classification predict(String outcome, String features) {
- Instance instance = new Instance(features, outcome, null, null);
- System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
- System.err.println("PREDICT features = " + (String) instance.getData());
-
- if (classifier == null)
- train();
-
- Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
- return result;
- }
-
- public void train() {
-// System.err.println(String.format("Word %s: training model", sourceWord));
-// System.err.println(String.format(" Examples: %s", examples));
-
- StringReader reader = new StringReader(examples);
-
- // Constructs an instance with everything shoved into the data field
- instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- classifier = trainer.train(instances);
-
- System.err.println(String.format("Trained a model for %s with %d outcomes",
- sourceWord, pipes.getTargetAlphabet().size()));
- }
-
- /**
- * Returns the number of distinct outcomes. Requires the model to have been trained!
- *
- * @return
- */
- public int getNumOutcomes() {
- if (classifier == null)
- train();
- return pipes.getTargetAlphabet().size();
- }
- }
-
- public static void example(String[] args) throws IOException, ClassNotFoundException {
-
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- Alphabet dataAlphabet = new Alphabet();
- LabelAlphabet labelAlphabet = new LabelAlphabet();
-
- pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
- // Basically, SvmLight but with a custom (fixed) alphabet)
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
-
- FileReader reader1 = new FileReader("data.1");
- FileReader reader2 = new FileReader("data.2");
-
- SerialPipes pipes = new SerialPipes(pipeList);
- InstanceList instances = new InstanceList(dataAlphabet, labelAlphabet);
- instances.setPipe(pipes);
- instances.addThruPipe(new CsvIterator(reader1, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
- ClassifierTrainer trainer1 = new MaxEntTrainer();
- Classifier classifier1 = trainer1.train(instances);
-
- pipes = new SerialPipes(pipeList);
- instances = new InstanceList(dataAlphabet, labelAlphabet);
- instances.setPipe(pipes);
- instances.addThruPipe(new CsvIterator(reader2, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
- ClassifierTrainer trainer2 = new MaxEntTrainer();
- Classifier classifier2 = trainer2.train(instances);
- }
-
- public static void main(String[] args) throws IOException, ClassNotFoundException {
- LexicalSharpener ts = new LexicalSharpener(null, args, null);
-
- String modelFile = "model";
-
- if (args.length > 0) {
- String dataFile = args[0];
-
- System.err.println("Training model from file " + dataFile);
- ts.trainAll(dataFile);
-
-// if (args.length > 1)
-// modelFile = args[1];
-//
-// System.err.println("Writing model to file " + modelFile);
-// ts.saveClassifiers(modelFile);
-// } else {
-// System.err.println("Loading model from file " + modelFile);
-// ts.loadClassifiers(modelFile);
- }
-
- Scanner stdin = new Scanner(System.in);
- while(stdin.hasNextLine()) {
- String line = stdin.nextLine();
- String[] tokens = line.split(" ", 3);
- String sourceWord = tokens[0];
- String targetWord = tokens[1];
- String features = tokens[2];
- Classification result = ts.predict(Vocabulary.id(sourceWord), Vocabulary.id(targetWord), features);
- if (result != null)
- System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
- else
- System.out.println("i got nothing");
- }
- }
-}
[08/13] incubator-joshua git commit: Rewrote in less efficient way,
now works!
Posted by mj...@apache.org.
Rewrote in less efficient way, now works!
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/b121fc20
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/b121fc20
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/b121fc20
Branch: refs/heads/morph
Commit: b121fc20c22dd6752ccf3e728ed8c23ca73eb9aa
Parents: 23306b4
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 17:59:50 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 17:59:50 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 267 ++++++++-----------
1 file changed, 110 insertions(+), 157 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b121fc20/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index a66f4a1..701a72f 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -55,42 +55,16 @@ import joshua.util.io.LineReader;
public class LexicalSharpener extends StatelessFF {
- private HashMap<Integer,Object> classifiers = null;
- private SerialPipes pipes = null;
- private InstanceList instances = null;
- private LabelAlphabet labelAlphabet = null;
- private Alphabet dataAlphabet = null;
-
+ private HashMap<Integer,Predictor> classifiers = null;
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
-
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- dataAlphabet = new Alphabet();
- labelAlphabet = new LabelAlphabet();
-
- // I don't know if this is needed
- pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
- // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
- pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
- // Validation
-// pipeList.add(new PrintInputAndTarget());
-
- // name: english word
- // data: features (FeatureVector)
- // target: foreign inflection
- // source: null
-
- pipes = new SerialPipes(pipeList);
- instances = new InstanceList(dataAlphabet, labelAlphabet);
- instances.setPipe(pipes);
if (parsedArgs.containsKey("model")) {
String modelFile = parsedArgs.get("model");
if (! new File(modelFile).exists()) {
if (parsedArgs.getOrDefault("training-data", null) != null) {
try {
- train(parsedArgs.get("training-data"));
+ trainAll(parsedArgs.get("training-data"));
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
@@ -113,49 +87,9 @@ public class LexicalSharpener extends StatelessFF {
}
}
- /**
- * Trains a maxent classifier from the provided training data, returning a Mallet model.
- *
- * @param dataFile
- * @return
- * @throws FileNotFoundException
- */
- public void train(String dataFile) throws FileNotFoundException {
-
- classifiers = new HashMap<Integer, Object>();
-
- Decoder.VERBOSE = 1;
- LineReader lineReader = null;
- try {
- lineReader = new LineReader(dataFile, true);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
-
- String lastOutcome = null;
- String buffer = "";
- int linesRead = 0;
- for (String line : lineReader) {
- String outcome = line.substring(0, line.indexOf(' '));
- if (lastOutcome != null && ! outcome.equals(lastOutcome)) {
- classifiers.put(Vocabulary.id(lastOutcome), buffer);
-// System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
- buffer = "";
- }
-
- buffer += line + "\n";
- lastOutcome = outcome;
- linesRead++;
- }
- classifiers.put(Vocabulary.id(lastOutcome), buffer);
-
- System.err.println(String.format("Read %d lines from training file", linesRead));
- }
-
public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
- classifiers = (HashMap<Integer,Object>) ois.readObject();
+ classifiers = (HashMap<Integer,Predictor>) ois.readObject();
ois.close();
System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
@@ -171,53 +105,6 @@ public class LexicalSharpener extends StatelessFF {
}
/**
- * Returns a Classification object a list of features. Uses "which" to determine which classifier
- * to use.
- *
- * @param which the classifier to use
- * @param features the set of features
- * @return
- */
- public Classification predict(String which, String features) {
- Instance instance = new Instance(features, which, which, which);
-// System.err.println("PREDICT outcome = " + (String) instance.getTarget());
-// System.err.println("PREDICT features = " + (String) instance.getData());
-
- Classifier classifier = getClassifier(which);
- if (classifier != null) {
- Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
- return result;
- }
-
- return null;
- }
-
- public Classifier getClassifier(String which) {
- int id = Vocabulary.id(which);
-
- if (classifiers.containsKey(id)) {
- if (classifiers.get(id) instanceof String) {
- StringReader reader = new StringReader((String)classifiers.get(id));
-
- System.err.println("training new classifier from string for " + which);
- System.err.println(String.format("training string is: '%s'", (String)classifiers.get(id)));
-
- // Constructs an instance with everything shoved into the data field
- instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
-
- ClassifierTrainer trainer = new MaxEntTrainer();
- Classifier classifier = trainer.train(instances);
-
- classifiers.put(id, classifier);
- }
-
- return (Classifier) classifiers.get(id);
- }
-
- return null;
- }
-
- /**
* Compute features. This works by walking over the target side phrase pieces, looking for every
* word with a single source-aligned word. We then throw the annotations from that source word
* into our prediction model to learn how much it likes the chosen word. Presumably the source-
@@ -238,7 +125,7 @@ public class LexicalSharpener extends StatelessFF {
Token sourceToken = sentence.getTokens().get(s);
String featureString = sourceToken.getAnnotationString().replace('|', ' ');
- Classification result = predict(targetWord, featureString);
+ Classification result = predict(sourceToken.getWord(), featureString);
if (result.bestLabelIsCorrect()) {
acc.add(String.format("%s_match", name), 1);
}
@@ -247,6 +134,17 @@ public class LexicalSharpener extends StatelessFF {
return null;
}
+ public Classification predict(int id, String featureString) {
+ String word = Vocabulary.word(id);
+ if (classifiers.containsKey(id)) {
+ Predictor predictor = classifiers.get(id);
+ if (predictor != null)
+ return predictor.predict(featureString);
+ }
+
+ return null;
+ }
+
/**
* Returns an array parallel to the source words array indicating, for each index, the absolute
* position of that word into the source sentence. For example, for the rule with source side
@@ -291,54 +189,109 @@ public class LexicalSharpener extends StatelessFF {
return anchoredSource;
}
- public static class CustomLineProcessor extends Pipe {
+ public class Predictor {
+
+ private SerialPipes pipes = null;
+ private InstanceList instances = null;
+ private String sourceWord = null;
+ private String examples = null;
+ private Classifier classifier = null;
+
+ public Predictor(String word, String examples) {
+ this.sourceWord = word;
+ this.examples = examples;
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
- private static final long serialVersionUID = 1L;
+ pipes = new SerialPipes(pipeList);
+ instances = new InstanceList(pipes);
+ }
+
+ /**
+ * Returns a Classification object a list of features. Uses "which" to determine which classifier
+ * to use.
+ *
+ * @param which the classifier to use
+ * @param features the set of features
+ * @return
+ */
+ public Classification predict(String features) {
+ Instance instance = new Instance(features, null, null, null);
+ // System.err.println("PREDICT sourceWord = " + (String) instance.getTarget());
+ // System.err.println("PREDICT features = " + (String) instance.getData());
+
+ if (classifier == null)
+ train();
- public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
- super(data, label);
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+ return result;
}
-
- @Override
- public Instance pipe(Instance carrier) {
- // we expect the data for each instance to be
- // a line from the SVMLight format text file
- String dataStr = (String)carrier.getData();
- String[] tokens = dataStr.split("\\s+");
+ public void train() {
+ System.err.println(String.format("Word %s: training model", sourceWord));
+ System.err.println(String.format(" Examples: %s", examples));
+
+ StringReader reader = new StringReader(examples);
-// carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+ // Constructs an instance with everything shoved into the data field
+ instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
- // the rest are feature-value pairs
- ArrayList<Integer> indices = new ArrayList<Integer>();
- ArrayList<Double> values = new ArrayList<Double>();
- for (int termIndex = 0; termIndex < tokens.length; termIndex++) {
- if (!tokens[termIndex].equals("")) {
- String[] s = tokens[termIndex].split(":");
- if (s.length != 2) {
- throw new RuntimeException("invalid format: " + tokens[termIndex] + " (should be feature:value)");
- }
- String feature = s[0];
- int index = getDataAlphabet().lookupIndex(feature, true);
- indices.add(index);
- values.add(Double.parseDouble(s[1]));
- }
- }
+ ClassifierTrainer trainer = new MaxEntTrainer();
+ classifier = trainer.train(instances);
+ }
+ }
+
+ /**
+ * Trains a maxent classifier from the provided training data, returning a Mallet model.
+ *
+ * @param dataFile
+ * @return
+ * @throws FileNotFoundException
+ */
+ public void trainAll(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Predictor>();
- assert(indices.size() == values.size());
- int[] indicesArr = new int[indices.size()];
- double[] valuesArr = new double[values.size()];
- for (int i = 0; i < indicesArr.length; i++) {
- indicesArr[i] = indices.get(i);
- valuesArr[i] = values.get(i);
+ Decoder.VERBOSE = 1;
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastSourceWord = null;
+ String examples = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String sourceWord = line.substring(0, line.indexOf(' '));
+ if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ examples = "";
}
- cc.mallet.types.FeatureVector fv = new cc.mallet.types.FeatureVector(getDataAlphabet(), indicesArr, valuesArr);
- carrier.setData(fv);
- return carrier;
+ examples += line + "\n";
+ lastSourceWord = sourceWord;
+ linesRead++;
}
- }
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ System.err.println(String.format("Read %d lines from training file", linesRead));
+ }
+
public static void example(String[] args) throws IOException, ClassNotFoundException {
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
@@ -348,7 +301,7 @@ public class LexicalSharpener extends StatelessFF {
pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
// Basically, SvmLight but with a custom (fixed) alphabet)
- pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
FileReader reader1 = new FileReader("data.1");
FileReader reader2 = new FileReader("data.2");
@@ -377,7 +330,7 @@ public class LexicalSharpener extends StatelessFF {
String dataFile = args[0];
System.err.println("Training model from file " + dataFile);
- ts.train(dataFile);
+ ts.trainAll(dataFile);
// if (args.length > 1)
// modelFile = args[1];
@@ -395,7 +348,7 @@ public class LexicalSharpener extends StatelessFF {
String[] tokens = line.split(" ", 2);
String sourceWord = tokens[0];
String features = tokens[1];
- Classification result = ts.predict(sourceWord, features);
+ Classification result = ts.predict(Vocabulary.id(sourceWord), features);
if (result != null)
System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
else
[09/13] incubator-joshua git commit: fiddling
Posted by mj...@apache.org.
fiddling
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/f0728fac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/f0728fac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/f0728fac
Branch: refs/heads/morph
Commit: f0728fac70bbacc7b33937f4915a091e4cfe4e36
Parents: b121fc2
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:11:54 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:11:54 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/LexicalSharpener.java | 159 +++++++++----------
1 file changed, 79 insertions(+), 80 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/f0728fac/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index 701a72f..7982db4 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -1,7 +1,5 @@
package joshua.decoder.ff.morph;
-import java.io.BufferedReader;
-
/***
* This feature function scores a rule application by predicting, for each target word aligned with
* a source word, how likely the lexical translation is in context.
@@ -59,34 +57,56 @@ public class LexicalSharpener extends StatelessFF {
public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
super(weights, "LexicalSharpener", args, config);
- if (parsedArgs.containsKey("model")) {
- String modelFile = parsedArgs.get("model");
- if (! new File(modelFile).exists()) {
- if (parsedArgs.getOrDefault("training-data", null) != null) {
- try {
- trainAll(parsedArgs.get("training-data"));
- } catch (FileNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- } else {
- System.err.println("* FATAL: no model and no training data.");
- System.exit(1);
- }
- } else {
- try {
- loadClassifiers(modelFile);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (ClassNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
+ if (parsedArgs.getOrDefault("training-data", null) != null) {
+ try {
+ trainAll(parsedArgs.get("training-data"));
+ } catch (FileNotFoundException e) {
+ System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
+ System.exit(1);
}
}
}
+ /**
+ * Trains a maxent classifier from the provided training data, returning a Mallet model.
+ *
+ * @param dataFile
+ * @return
+ * @throws FileNotFoundException
+ */
+ public void trainAll(String dataFile) throws FileNotFoundException {
+
+ classifiers = new HashMap<Integer, Predictor>();
+
+ Decoder.LOG(1, "Reading " + dataFile);
+ LineReader lineReader = null;
+ try {
+ lineReader = new LineReader(dataFile, true);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ String lastSourceWord = null;
+ String examples = "";
+ int linesRead = 0;
+ for (String line : lineReader) {
+ String sourceWord = line.substring(0, line.indexOf(' '));
+ if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+ // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+ examples = "";
+ }
+
+ examples += line + "\n";
+ lastSourceWord = sourceWord;
+ linesRead++;
+ }
+ classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+
+ System.err.println(String.format("Read %d lines from training file", linesRead));
+ }
+
public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
classifiers = (HashMap<Integer,Predictor>) ois.readObject();
@@ -113,6 +133,8 @@ public class LexicalSharpener extends StatelessFF {
@Override
public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
Sentence sentence, Accumulator acc) {
+
+ System.err.println(String.format("RULE: %s", rule));
Map<Integer, List<Integer>> points = rule.getAlignmentMap();
for (int t: points.keySet()) {
@@ -120,12 +142,14 @@ public class LexicalSharpener extends StatelessFF {
if (source_indices.size() != 1)
continue;
- String targetWord = Vocabulary.word(rule.getEnglish()[t]);
+ int targetID = rule.getEnglish()[t];
+ String targetWord = Vocabulary.word(targetID);
int s = i + source_indices.get(0);
Token sourceToken = sentence.getTokens().get(s);
String featureString = sourceToken.getAnnotationString().replace('|', ' ');
- Classification result = predict(sourceToken.getWord(), featureString);
+ Classification result = predict(sourceToken.getWord(), targetID, featureString);
+ System.out.println("RESULT: " + result.getLabeling());
if (result.bestLabelIsCorrect()) {
acc.add(String.format("%s_match", name), 1);
}
@@ -134,12 +158,12 @@ public class LexicalSharpener extends StatelessFF {
return null;
}
- public Classification predict(int id, String featureString) {
- String word = Vocabulary.word(id);
- if (classifiers.containsKey(id)) {
- Predictor predictor = classifiers.get(id);
+ public Classification predict(int sourceID, int targetID, String featureString) {
+ String word = Vocabulary.word(sourceID);
+ if (classifiers.containsKey(sourceID)) {
+ Predictor predictor = classifiers.get(sourceID);
if (predictor != null)
- return predictor.predict(featureString);
+ return predictor.predict(Vocabulary.word(targetID), featureString);
}
return null;
@@ -226,10 +250,10 @@ public class LexicalSharpener extends StatelessFF {
* @param features the set of features
* @return
*/
- public Classification predict(String features) {
- Instance instance = new Instance(features, null, null, null);
- // System.err.println("PREDICT sourceWord = " + (String) instance.getTarget());
- // System.err.println("PREDICT features = " + (String) instance.getData());
+ public Classification predict(String outcome, String features) {
+ Instance instance = new Instance(features, outcome, null, null);
+ System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
+ System.err.println("PREDICT features = " + (String) instance.getData());
if (classifier == null)
train();
@@ -239,8 +263,8 @@ public class LexicalSharpener extends StatelessFF {
}
public void train() {
- System.err.println(String.format("Word %s: training model", sourceWord));
- System.err.println(String.format(" Examples: %s", examples));
+// System.err.println(String.format("Word %s: training model", sourceWord));
+// System.err.println(String.format(" Examples: %s", examples));
StringReader reader = new StringReader(examples);
@@ -249,47 +273,21 @@ public class LexicalSharpener extends StatelessFF {
ClassifierTrainer trainer = new MaxEntTrainer();
classifier = trainer.train(instances);
- }
- }
-
- /**
- * Trains a maxent classifier from the provided training data, returning a Mallet model.
- *
- * @param dataFile
- * @return
- * @throws FileNotFoundException
- */
- public void trainAll(String dataFile) throws FileNotFoundException {
-
- classifiers = new HashMap<Integer, Predictor>();
-
- Decoder.VERBOSE = 1;
- LineReader lineReader = null;
- try {
- lineReader = new LineReader(dataFile, true);
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
+
+ System.err.println(String.format("Trained a model for %s with %d outcomes",
+ sourceWord, pipes.getTargetAlphabet().size()));
}
- String lastSourceWord = null;
- String examples = "";
- int linesRead = 0;
- for (String line : lineReader) {
- String sourceWord = line.substring(0, line.indexOf(' '));
- if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
- // System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
- examples = "";
- }
-
- examples += line + "\n";
- lastSourceWord = sourceWord;
- linesRead++;
+ /**
+ * Returns the number of distinct outcomes. Requires the model to have been trained!
+ *
+ * @return
+ */
+ public int getNumOutcomes() {
+ if (classifier == null)
+ train();
+ return pipes.getTargetAlphabet().size();
}
- classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
-
- System.err.println(String.format("Read %d lines from training file", linesRead));
}
public static void example(String[] args) throws IOException, ClassNotFoundException {
@@ -345,10 +343,11 @@ public class LexicalSharpener extends StatelessFF {
Scanner stdin = new Scanner(System.in);
while(stdin.hasNextLine()) {
String line = stdin.nextLine();
- String[] tokens = line.split(" ", 2);
+ String[] tokens = line.split(" ", 3);
String sourceWord = tokens[0];
- String features = tokens[1];
- Classification result = ts.predict(Vocabulary.id(sourceWord), features);
+ String targetWord = tokens[1];
+ String features = tokens[2];
+ Classification result = ts.predict(Vocabulary.id(sourceWord), Vocabulary.id(targetWord), features);
if (result != null)
System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
else
[02/13] incubator-joshua git commit: Added note about usage and data
format
Posted by mj...@apache.org.
Added note about usage and data format
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/dca7a5dc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/dca7a5dc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/dca7a5dc
Branch: refs/heads/morph
Commit: dca7a5dc6ee3690132d3e4c0be1eef2f69bf3115
Parents: e3ad1a6
Author: Matt Post <po...@cs.jhu.edu>
Authored: Wed Apr 20 17:21:44 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Wed Apr 20 17:21:44 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/InflectionPredictor.java | 17 +++++++++++------
1 file changed, 11 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dca7a5dc/src/joshua/decoder/ff/morph/InflectionPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/InflectionPredictor.java b/src/joshua/decoder/ff/morph/InflectionPredictor.java
index 82d52ea..5282497 100644
--- a/src/joshua/decoder/ff/morph/InflectionPredictor.java
+++ b/src/joshua/decoder/ff/morph/InflectionPredictor.java
@@ -1,19 +1,24 @@
package joshua.decoder.ff.morph;
+/*
+ * Format of training file:
+ *
+ * source_word target_word feature:value feature:value feature:value ...
+ *
+ * Invocation:
+ *
+ * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.InflectionPredictor /path/to/training/data
+ */
+
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
-import cc.mallet.classify.Classifier;
-import cc.mallet.classify.ClassifierTrainer;
-import cc.mallet.classify.MaxEntTrainer;
-import cc.mallet.classify.NaiveBayesTrainer;
+import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.CsvIterator;
-import cc.mallet.pipe.iterator.FileIterator;
-import cc.mallet.pipe.iterator.LineIterator;
import cc.mallet.types.InstanceList;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.SourcePath;
[03/13] incubator-joshua git commit: Added full-file training,
start of feature function
Posted by mj...@apache.org.
Added full-file training, start of feature function
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/47f1af58
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/47f1af58
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/47f1af58
Branch: refs/heads/morph
Commit: 47f1af588a8878963112b11ef29e4180e9d82869
Parents: dca7a5d
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 09:15:09 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 09:15:09 2016 -0400
----------------------------------------------------------------------
.../decoder/ff/morph/InflectionPredictor.java | 193 ++++++++++++++++---
src/joshua/decoder/segment_file/Token.java | 12 +-
2 files changed, 178 insertions(+), 27 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/47f1af58/src/joshua/decoder/ff/morph/InflectionPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/InflectionPredictor.java b/src/joshua/decoder/ff/morph/InflectionPredictor.java
index 5282497..f4a4310 100644
--- a/src/joshua/decoder/ff/morph/InflectionPredictor.java
+++ b/src/joshua/decoder/ff/morph/InflectionPredictor.java
@@ -1,25 +1,40 @@
package joshua.decoder.ff.morph;
-/*
+/***
+ * This feature function scores a rule application by predicting, for each target word aligned with
+ * a source word, how likely the lexical translation is in context.
+ *
+ * The feature function can be provided with a trained model or a raw training file which it will
+ * then train prior to decoding.
+ *
* Format of training file:
*
* source_word target_word feature:value feature:value feature:value ...
*
* Invocation:
*
- * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.InflectionPredictor /path/to/training/data
+ * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.LexicalSharpener /path/to/training/data
*/
import java.io.File;
+import java.io.FileInputStream;
import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
import java.io.FileReader;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
+import joshua.corpus.Vocabulary;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.FeatureVector;
@@ -28,13 +43,31 @@ import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.segment_file.Sentence;
+import joshua.decoder.segment_file.Token;
public class InflectionPredictor extends StatelessFF {
private Classifier classifier = null;
+ private SerialPipes pipes = null;
public InflectionPredictor(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
- super(weights, "InfectionPredictor", args, config);
+ super(weights, "LexicalSharpener", args, config);
+
+ ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+ // I don't know if this is needed
+ pipeList.add(new Target2Label());
+ // Convert SVM-light format to sparse feature vector
+ pipeList.add(new SvmLight2FeatureVectorAndLabel());
+ // Validation
+// pipeList.add(new PrintInputAndTarget());
+
+ // name: english word
+ // data: features (FeatureVector)
+ // target: foreign inflection
+ // source: null
+
+ pipes = new SerialPipes(pipeList);
if (parsedArgs.containsKey("model")) {
String modelFile = parsedArgs.get("model");
@@ -51,28 +84,30 @@ public class InflectionPredictor extends StatelessFF {
System.exit(1);
}
} else {
- // TODO: load the model
+ try {
+ loadClassifier(modelFile);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (ClassNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
}
}
}
+ /**
+ * Trains a maxent classifier from the provided training data, returning a Mallet model.
+ *
+ * @param dataFile
+ * @return
+ * @throws FileNotFoundException
+ */
public Classifier train(String dataFile) throws FileNotFoundException {
- ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
- // I don't know if this is needed
- pipeList.add(new Target2Label());
- // Convert SVM-light format to sparse feature vector
- pipeList.add(new SvmLight2FeatureVectorAndLabel());
- // Validation
-// pipeList.add(new PrintInputAndTarget());
-
- // name: english word
- // data: features (FeatureVector)
- // target: foreign inflection
- // source: null
// Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
- InstanceList instances = new InstanceList(new SerialPipes(pipeList));
+ InstanceList instances = new InstanceList(pipes);
instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
"(\\w+)\\s+(.*)",
2, -1, 1));
@@ -82,22 +117,130 @@ public class InflectionPredictor extends StatelessFF {
return classifier;
}
-
+
+ public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
+ ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
+ classifier = (Classifier) ois.readObject();
+ }
+
+ public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
+ ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
+ oos.writeObject(classifier);
+ oos.close();
+ }
+
+ public Classification predict(String outcome, String features) {
+ Instance instance = new Instance(features, null, null, null);
+ System.err.println("PREDICT outcome = " + (String) instance.getTarget());
+ System.err.println("PREDICT features = " + (String) instance.getData());
+ Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+
+ return result;
+ }
+
+ /**
+ * Compute features. This works by walking over the target side phrase pieces, looking for every
+ * word with a single source-aligned word. We then throw the annotations from that source word
+ * into our prediction model to learn how much it likes the chosen word. Presumably the source-
+ * language annotations have contextual features, so this effectively chooses the words in context.
+ */
@Override
public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
Sentence sentence, Accumulator acc) {
+
+ Map<Integer, List<Integer>> points = rule.getAlignmentMap();
+ for (int t: points.keySet()) {
+ List<Integer> source_indices = points.get(t);
+ if (source_indices.size() != 1)
+ continue;
+
+ String targetWord = Vocabulary.word(rule.getEnglish()[t]);
+ int s = i + source_indices.get(0);
+ Token sourceToken = sentence.getTokens().get(s);
+ String featureString = sourceToken.getAnnotationString().replace('|', ' ');
+
+ Classification result = predict(targetWord, featureString);
+ if (result.bestLabelIsCorrect()) {
+ acc.add(String.format("%s_match", name), 1);
+ }
+ }
return null;
}
- public static void main(String[] args) throws FileNotFoundException {
- InflectionPredictor ip = new InflectionPredictor(null, args, null);
+ /**
+ * Returns an array parallel to the source words array indicating, for each index, the absolute
+ * position of that word into the source sentence. For example, for the rule with source side
+ *
+ * [ 17, 142, -14, 9 ]
+ *
+ * and source sentence
+ *
+ * [ 17, 18, 142, 1, 1, 9, 8 ]
+ *
+ * it will return
+ *
+ * [ 0, 2, -14, 5 ]
+ *
+ * which indicates that the first, second, and fourth words of the rule are anchored to the
+ * first, third, and sixth words of the input sentence.
+ *
+ * @param rule
+ * @param tailNodes
+ * @param start
+ * @return a list of alignment points anchored to the source sentence
+ */
+ public int[] anchorRuleSourceToSentence(Rule rule, List<HGNode> tailNodes, int start) {
+ int[] source = rule.getFrench();
+
+ // Map the source words in the rule to absolute positions in the sentence
+ int[] anchoredSource = source.clone();
- String dataFile = "/Users/post/Desktop/amazon16/model";
- if (args.length > 0)
- dataFile = args[0];
+ int sourceIndex = start;
+ int tailNodeIndex = 0;
+ for (int i = 0; i < source.length; i++) {
+ if (source[i] < 0) { // nonterminal
+ anchoredSource[i] = source[i];
+ sourceIndex = tailNodes.get(tailNodeIndex).j;
+ tailNodeIndex++;
+ } else { // terminal
+ anchoredSource[i] = sourceIndex;
+ sourceIndex++;
+ }
+ }
- ip.train(dataFile);
+ return anchoredSource;
}
+ public static void main(String[] args) throws IOException, ClassNotFoundException {
+ InflectionPredictor ts = new InflectionPredictor(null, args, null);
+
+ String modelFile = "model";
+
+ if (args.length > 0) {
+ String dataFile = args[0];
+
+ System.err.println("Training model from file " + dataFile);
+ ts.train(dataFile);
+
+ if (args.length > 1)
+ modelFile = args[1];
+
+ System.err.println("Writing model to file " + modelFile);
+ ts.saveClassifier(modelFile);
+ } else {
+ System.err.println("Loading model from file " + modelFile);
+ ts.loadClassifier(modelFile);
+ }
+
+ Scanner stdin = new Scanner(System.in);
+ while(stdin.hasNextLine()) {
+ String line = stdin.nextLine();
+ String[] tokens = line.split(" ", 2);
+ String outcome = tokens[0];
+ String features = tokens[1];
+ Classification result = ts.predict(outcome, features);
+ System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/47f1af58/src/joshua/decoder/segment_file/Token.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/segment_file/Token.java b/src/joshua/decoder/segment_file/Token.java
index 12e2b68..7969294 100644
--- a/src/joshua/decoder/segment_file/Token.java
+++ b/src/joshua/decoder/segment_file/Token.java
@@ -36,6 +36,7 @@ public class Token {
private int tokenID;
private HashMap<String,String> annotations = null;
+ private String annotationString;
/**
* Constructor : Creates a Token object from a raw word
@@ -69,9 +70,9 @@ public class Token {
if (tag.find()) {
// Annotation match found
token = tag.group(1);
- String tagStr = tag.group(2);
+ annotationString = tag.group(2);
- for (String annotation: tagStr.split(";")) {
+ for (String annotation: annotationString.split(";")) {
int where = annotation.indexOf("=");
if (where != -1) {
annotations.put(annotation.substring(0, where), annotation.substring(where + 1));
@@ -121,4 +122,11 @@ public class Token {
return null;
}
+
+ /**
+ * Returns the raw annotation string
+ */
+ public String getAnnotationString() {
+ return annotationString;
+ }
}
\ No newline at end of file
[12/13] incubator-joshua git commit: added new libs to invocation
Posted by mj...@apache.org.
added new libs to invocation
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/95ddf09a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/95ddf09a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/95ddf09a
Branch: refs/heads/morph
Commit: 95ddf09ab08f15710ad8b6f1ef4b8c7ccce3dbda
Parents: 61cd489
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:12:39 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:12:39 2016 -0400
----------------------------------------------------------------------
bin/joshua-decoder | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/95ddf09a/bin/joshua-decoder
----------------------------------------------------------------------
diff --git a/bin/joshua-decoder b/bin/joshua-decoder
index cdb2cf4..5574c48 100755
--- a/bin/joshua-decoder
+++ b/bin/joshua-decoder
@@ -27,7 +27,7 @@ set -u
JOSHUA=$(dirname $0)/..
exec java -Xmx${mem} \
- -cp $JOSHUA/class:$JOSHUA/ext/berkeleylm/jar/berkeleylm.jar:$JOSHUA/lib/gson-2.5.jar:$JOSHUA/lib/guava-19.0.jar \
+ -cp $JOSHUA/class:$JOSHUA/ext/berkeleylm/jar/berkeleylm.jar:$JOSHUA/lib/gson-2.5.jar:$JOSHUA/lib/guava-19.0.jar:$JOSHUA/lib/mallet-2.0.7.jar:$JOSHUA/lib/trove4j-2.0.2.jar \
-Dfile.encoding=utf8 \
-Djava.util.logging.config.file=${JOSHUA}/logging.properties \
-Djava.library.path=$JOSHUA/lib \
[10/13] incubator-joshua git commit: bugfix: initialize
annotationString
Posted by mj...@apache.org.
bugfix: initialize annotationString
Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/4ec7ddb1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/4ec7ddb1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/4ec7ddb1
Branch: refs/heads/morph
Commit: 4ec7ddb1084d92a2b93ac81c7901dbe7251d5d02
Parents: f0728fa
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:12:09 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:12:09 2016 -0400
----------------------------------------------------------------------
src/joshua/decoder/segment_file/Token.java | 1 +
1 file changed, 1 insertion(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/4ec7ddb1/src/joshua/decoder/segment_file/Token.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/segment_file/Token.java b/src/joshua/decoder/segment_file/Token.java
index 7969294..9dcec22 100644
--- a/src/joshua/decoder/segment_file/Token.java
+++ b/src/joshua/decoder/segment_file/Token.java
@@ -62,6 +62,7 @@ public class Token {
public Token(String rawWord) {
annotations = new HashMap<String,String>();
+ annotationString = "";
// Matches a word with an annotation
// Check guidelines in constructor description