You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/24 22:53:26 UTC

[05/18] incubator-joshua git commit: Model now serializes

Model now serializes


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/68b01bc1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/68b01bc1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/68b01bc1

Branch: refs/heads/morph
Commit: 68b01bc168298db382334e9f01bdf2992db85b01
Parents: 1c8aaa5
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 15:59:39 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 15:59:39 2016 -0400

----------------------------------------------------------------------
 src/joshua/decoder/ff/LexicalSharpener.java | 184 ++++++-----------------
 src/joshua/decoder/ff/MalletPredictor.java  |  97 ++++++++++++
 2 files changed, 143 insertions(+), 138 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/68b01bc1/src/joshua/decoder/ff/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/LexicalSharpener.java b/src/joshua/decoder/ff/LexicalSharpener.java
index 2c96f83..8671d57 100644
--- a/src/joshua/decoder/ff/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/LexicalSharpener.java
@@ -19,24 +19,16 @@ package joshua.decoder.ff;
 import java.io.FileInputStream;
 import java.io.FileNotFoundException;
 import java.io.FileOutputStream;
-import java.io.FileReader;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
-import java.io.StringReader;
-import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Scanner;
 
 import cc.mallet.classify.*;
-import cc.mallet.pipe.*;
-import cc.mallet.pipe.iterator.CsvIterator;
-import cc.mallet.types.Alphabet;
-import cc.mallet.types.Instance;
-import cc.mallet.types.InstanceList;
-import cc.mallet.types.LabelAlphabet;
+import cc.mallet.types.Labeling;
 import joshua.corpus.Vocabulary;
 import joshua.decoder.Decoder;
 import joshua.decoder.JoshuaConfiguration;
@@ -52,7 +44,8 @@ import joshua.util.io.LineReader;
 
 public class LexicalSharpener extends StatelessFF {
 
-  private HashMap<Integer,Predictor> classifiers = null;
+  private HashMap<String,MalletPredictor> classifiers = null;
+
   public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
     super(weights, "LexicalSharpener", args, config);
 
@@ -63,6 +56,13 @@ public class LexicalSharpener extends StatelessFF {
         System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
         System.exit(1);
       }
+    } else if (parsedArgs.containsKey("model")) {
+      try {
+        loadClassifiers(parsedArgs.get("model"));
+      } catch (ClassNotFoundException | IOException e) {
+        // TODO Auto-generated catch block
+        e.printStackTrace();
+      }
     }
   }
   
@@ -75,7 +75,7 @@ public class LexicalSharpener extends StatelessFF {
    */
   public void trainAll(String dataFile) throws FileNotFoundException {
   
-    classifiers = new HashMap<Integer, Predictor>();
+    classifiers = new HashMap<String, MalletPredictor>();
 
     Decoder.LOG(1, "Reading " + dataFile);
     LineReader lineReader = null;
@@ -92,7 +92,7 @@ public class LexicalSharpener extends StatelessFF {
     for (String line : lineReader) {
       String sourceWord = line.substring(0, line.indexOf(' '));
       if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
-        classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+        classifiers.put(lastSourceWord, new MalletPredictor(lastSourceWord, examples));
         //        System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
         examples = "";
       }
@@ -101,18 +101,18 @@ public class LexicalSharpener extends StatelessFF {
       lastSourceWord = sourceWord;
       linesRead++;
     }
-    classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+    classifiers.put(lastSourceWord, new MalletPredictor(lastSourceWord, examples));
   
     System.err.println(String.format("Read %d lines from training file", linesRead));
   }
 
   public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
     ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
-    classifiers = (HashMap<Integer,Predictor>) ois.readObject();
+    classifiers = (HashMap<String,MalletPredictor>) ois.readObject();
     ois.close();
     
     System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
-    for (int key: classifiers.keySet()) {
+    for (String key: classifiers.keySet()) {
       System.err.println("  " + key);
     }
   }
@@ -133,8 +133,6 @@ public class LexicalSharpener extends StatelessFF {
   public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
       Sentence sentence, Accumulator acc) {
     
-    System.err.println(String.format("RULE: %s",  rule));
-        
     Map<Integer, List<Integer>> points = rule.getAlignmentMap();
     for (int t: points.keySet()) {
       List<Integer> source_indices = points.get(t);
@@ -142,27 +140,46 @@ public class LexicalSharpener extends StatelessFF {
         continue;
       
       int targetID = rule.getEnglish()[t];
-      String targetWord = Vocabulary.word(targetID);
       int s = i + source_indices.get(0);
       Token sourceToken = sentence.getTokens().get(s);
       String featureString = sourceToken.getAnnotationString().replace('|', ' ');
       
       Classification result = predict(sourceToken.getWord(), targetID, featureString);
-      System.out.println("RESULT: " + result.getLabeling());
-      if (result.bestLabelIsCorrect()) {
-        acc.add(String.format("%s_match", name), 1);
+      if (result != null) {
+        Labeling labeling = result.getLabeling();
+        int num = labeling.numLocations();
+        int predicted = Vocabulary.id(labeling.getBestLabel().toString());
+//        System.err.println(String.format("LexicalSharpener: predicted %s (rule %s) %.5f",
+//            labeling.getBestLabel().toString(), Vocabulary.word(targetID), Math.log(labeling.getBestValue())));
+        if (num > 1 && predicted == targetID) {
+          acc.add(String.format("%s_match_%s", name, getBin(num)), 1);
+        }
+        acc.add(String.format("%s_weight", name), (float) Math.log(labeling.getBestValue()));
       }
     }
     
     return null;
   }
   
+  private String getBin(int num) {
+    if (num == 2)
+      return "2";
+    else if (num <= 5)
+      return "3-5";
+    else if (num <= 10)
+      return "6-10";
+    else if (num <= 20)
+      return "11-20";
+    else
+      return "21+";
+  }
+  
   public Classification predict(int sourceID, int targetID, String featureString) {
     String word = Vocabulary.word(sourceID);
-    if (classifiers.containsKey(sourceID)) {
-      Predictor predictor = classifiers.get(sourceID);
+    if (classifiers.containsKey(word)) {
+      MalletPredictor predictor = classifiers.get(word);
       if (predictor != null)
-        return predictor.predict(Vocabulary.word(targetID), featureString);
+        return predictor.predict(word, featureString);
     }
 
     return null;
@@ -212,112 +229,6 @@ public class LexicalSharpener extends StatelessFF {
     return anchoredSource;
   }
   
-  public class Predictor {
-    
-    private SerialPipes pipes = null;
-    private InstanceList instances = null;
-    private String sourceWord = null;
-    private String examples = null;
-    private Classifier classifier = null;
-    
-    public Predictor(String word, String examples) {
-      this.sourceWord = word;
-      this.examples = examples;
-      ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
-      // I don't know if this is needed
-      pipeList.add(new Target2Label());
-      // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
-      pipeList.add(new SvmLight2FeatureVectorAndLabel());
-      // Validation
-//      pipeList.add(new PrintInputAndTarget());
-      
-      // name: english word
-      // data: features (FeatureVector)
-      // target: foreign inflection
-      // source: null
-
-      pipes = new SerialPipes(pipeList);
-      instances = new InstanceList(pipes);
-    }
-
-    /**
-       * Returns a Classification object a list of features. Uses "which" to determine which classifier
-       * to use.
-       *   
-       * @param which the classifier to use
-       * @param features the set of features
-       * @return
-       */
-    public Classification predict(String outcome, String features) {
-      Instance instance = new Instance(features, outcome, null, null);
-      System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
-      System.err.println("PREDICT features = " + (String) instance.getData());
-
-      if (classifier == null)
-        train();
-
-      Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
-      return result;
-    }
-
-    public void train() {
-//      System.err.println(String.format("Word %s: training model", sourceWord));
-//      System.err.println(String.format("  Examples: %s", examples));
-      
-      StringReader reader = new StringReader(examples);
-
-      // Constructs an instance with everything shoved into the data field
-      instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
-
-      ClassifierTrainer trainer = new MaxEntTrainer();
-      classifier = trainer.train(instances);
-      
-      System.err.println(String.format("Trained a model for %s with %d outcomes", 
-          sourceWord, pipes.getTargetAlphabet().size()));
-    }
-
-    /**
-     * Returns the number of distinct outcomes. Requires the model to have been trained!
-     * 
-     * @return
-     */
-    public int getNumOutcomes() {
-      if (classifier == null)
-        train();
-      return pipes.getTargetAlphabet().size();
-    }
-  }
-  
-  public static void example(String[] args) throws IOException, ClassNotFoundException {
-
-    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
-    Alphabet dataAlphabet = new Alphabet();
-    LabelAlphabet labelAlphabet = new LabelAlphabet();
-    
-    pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
-    // Basically, SvmLight but with a custom (fixed) alphabet)
-    pipeList.add(new SvmLight2FeatureVectorAndLabel());
-
-    FileReader reader1 = new FileReader("data.1");
-    FileReader reader2 = new FileReader("data.2");
-
-    SerialPipes pipes = new SerialPipes(pipeList);
-    InstanceList instances = new InstanceList(dataAlphabet, labelAlphabet);
-    instances.setPipe(pipes);
-    instances.addThruPipe(new CsvIterator(reader1, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
-    ClassifierTrainer trainer1 = new MaxEntTrainer();
-    Classifier classifier1 = trainer1.train(instances);
-    
-    pipes = new SerialPipes(pipeList);
-    instances = new InstanceList(dataAlphabet, labelAlphabet);
-    instances.setPipe(pipes);
-    instances.addThruPipe(new CsvIterator(reader2, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
-    ClassifierTrainer trainer2 = new MaxEntTrainer();
-    Classifier classifier2 = trainer2.train(instances);
-  }
-  
   public static void main(String[] args) throws IOException, ClassNotFoundException {
     LexicalSharpener ts = new LexicalSharpener(null, args, null);
     
@@ -329,14 +240,11 @@ public class LexicalSharpener extends StatelessFF {
       System.err.println("Training model from file " + dataFile);
       ts.trainAll(dataFile);
     
-//      if (args.length > 1)
-//        modelFile = args[1];
-//      
-//      System.err.println("Writing model to file " + modelFile); 
-//      ts.saveClassifiers(modelFile);
-//    } else {
-//      System.err.println("Loading model from file " + modelFile);
-//      ts.loadClassifiers(modelFile);
+      if (args.length > 1)
+        modelFile = args[1];
+      
+      System.err.println("Writing model to file " + modelFile); 
+      ts.saveClassifiers(modelFile);
     }
     
     Scanner stdin = new Scanner(System.in);

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/68b01bc1/src/joshua/decoder/ff/MalletPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/MalletPredictor.java b/src/joshua/decoder/ff/MalletPredictor.java
new file mode 100644
index 0000000..04c9d8c
--- /dev/null
+++ b/src/joshua/decoder/ff/MalletPredictor.java
@@ -0,0 +1,97 @@
+package joshua.decoder.ff;
+
+import java.io.Serializable;
+import java.io.StringReader;
+import java.util.ArrayList;
+
+import cc.mallet.classify.Classification;
+import cc.mallet.classify.Classifier;
+import cc.mallet.classify.ClassifierTrainer;
+import cc.mallet.classify.MaxEntTrainer;
+import cc.mallet.pipe.Pipe;
+import cc.mallet.pipe.SerialPipes;
+import cc.mallet.pipe.SvmLight2FeatureVectorAndLabel;
+import cc.mallet.pipe.Target2Label;
+import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Instance;
+import cc.mallet.types.InstanceList;
+import joshua.decoder.Decoder;
+
+public class MalletPredictor implements Serializable {
+    // Holds the pipes, raw examples and (once trained) MaxEnt classifier for one source word.
+    private static final long serialVersionUID = 1L;
+
+    private SerialPipes pipes = null;
+    private InstanceList instances = null;
+    private String sourceWord = null;
+    private String examples = null;
+    private Classifier classifier = null;
+    // Stores the word and its raw examples and builds the pipes; no training happens here.
+    public MalletPredictor(String word, String examples) {
+      this.sourceWord = word;
+      this.examples = examples;
+      ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+      // NOTE(review): the original author was unsure whether this Target2Label pipe is needed
+      pipeList.add(new Target2Label());
+      // Convert lines to Instance objects. NOTE(review): an earlier note called this pipe "not versatile enough"
+      pipeList.add(new SvmLight2FeatureVectorAndLabel());
+      // Validation
+//      pipeList.add(new PrintInputAndTarget());
+      
+      // name: english word
+      // data: features (FeatureVector)
+      // target: foreign inflection
+      // source: null
+
+      pipes = new SerialPipes(pipeList);
+      instances = new InstanceList(pipes);
+    }
+
+    /**
+       * Classifies a feature string with this word's classifier, training the classifier first if
+       * train() has not been run yet (that is, while classifier is still null).
+       *   
+       * @param outcome used as the target of the Instance that is built for classification
+       * @param features used as the data of the Instance that is built for classification
+       * @return the Classification the classifier produces for that instance
+       */
+    public Classification predict(String outcome, String features) {
+      Instance instance = new Instance(features, outcome, null, null);
+//      System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
+//      System.err.println("PREDICT features = " + (String) instance.getData());
+
+      if (classifier == null)
+        train();
+
+      Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+      return result;
+    }
+    // Feeds the stored examples through the pipes and fits a MaxEnt classifier on the result.
+    public void train() {
+      Decoder.LOG(2, String.format("Word %s: training model from %d examples", 
+          sourceWord, examples.split("\\n").length));
+      
+      StringReader reader = new StringReader(examples);
+
+      // Constructs an instance with everything shoved into the data field
+      instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
+
+      ClassifierTrainer trainer = new MaxEntTrainer();
+      classifier = trainer.train(instances);
+      
+//      Decoder.LOG(1, String.format("%s: Trained a model for %s with %d outcomes", 
+//          name, sourceWord, pipes.getTargetAlphabet().size()));
+    }
+
+    /**
+     * Returns the size of the pipes' target alphabet, training the model first if it is untrained.
+     * 
+     * @return the number of entries in the target alphabet (presumably the distinct outcomes)
+     */
+    public int getNumOutcomes() {
+      if (classifier == null)
+        train();
+      return pipes.getTargetAlphabet().size();
+    }
+  }
\ No newline at end of file