You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:49 UTC

[03/13] incubator-joshua git commit: Added full-file training, start of feature function

Added full-file training, start of feature function


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/47f1af58
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/47f1af58
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/47f1af58

Branch: refs/heads/morph
Commit: 47f1af588a8878963112b11ef29e4180e9d82869
Parents: dca7a5d
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 09:15:09 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 09:15:09 2016 -0400

----------------------------------------------------------------------
 .../decoder/ff/morph/InflectionPredictor.java   | 193 ++++++++++++++++---
 src/joshua/decoder/segment_file/Token.java      |  12 +-
 2 files changed, 178 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/47f1af58/src/joshua/decoder/ff/morph/InflectionPredictor.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/InflectionPredictor.java b/src/joshua/decoder/ff/morph/InflectionPredictor.java
index 5282497..f4a4310 100644
--- a/src/joshua/decoder/ff/morph/InflectionPredictor.java
+++ b/src/joshua/decoder/ff/morph/InflectionPredictor.java
@@ -1,25 +1,40 @@
 package joshua.decoder.ff.morph;
 
-/*
+/***
+ * This feature function scores a rule application by predicting, for each target word aligned with
+ * exactly one source word, how likely the lexical translation is in context.
+ * 
+ * The feature function can be provided with either a trained model or a raw training file, from
+ * which it will train a model prior to decoding.
+ * 
  * Format of training file:
  * 
  * source_word target_word feature:value feature:value feature:value ...
  * 
  * Invocation:
  * 
- * java -cp /Users/post/code/joshua/lib/mallet-2.0.7.jar:/Users/post/code/joshua/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.InflectionPredictor /path/to/training/data 
+ * java -cp $JOSHUA/lib/mallet-2.0.7.jar:$JOSHUA/lib/trove4j-2.0.2.jar:$JOSHUA/class joshua.decoder.ff.morph.InflectionPredictor /path/to/training/data [/path/to/model]
  */
 
 import java.io.File;
+import java.io.FileInputStream;
 import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
 import java.io.FileReader;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
 
 import cc.mallet.classify.*;
 import cc.mallet.pipe.*;
 import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Instance;
 import cc.mallet.types.InstanceList;
+import joshua.corpus.Vocabulary;
 import joshua.decoder.JoshuaConfiguration;
 import joshua.decoder.chart_parser.SourcePath;
 import joshua.decoder.ff.FeatureVector;
@@ -28,13 +43,31 @@ import joshua.decoder.ff.state_maintenance.DPState;
 import joshua.decoder.ff.tm.Rule;
 import joshua.decoder.hypergraph.HGNode;
 import joshua.decoder.segment_file.Sentence;
+import joshua.decoder.segment_file.Token;
 
 public class InflectionPredictor extends StatelessFF {
 
   private Classifier classifier = null;
+  private SerialPipes pipes = null;
   
   public InflectionPredictor(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
-    super(weights, "InfectionPredictor", args, config);
+    super(weights, "LexicalSharpener", args, config);
+    
+    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+    // I don't know if this is needed
+    pipeList.add(new Target2Label());
+    // Convert SVM-light format to sparse feature vector
+    pipeList.add(new SvmLight2FeatureVectorAndLabel());
+    // Validation
+//    pipeList.add(new PrintInputAndTarget());
+    
+    // name: english word
+    // data: features (FeatureVector)
+    // target: foreign inflection
+    // source: null
+
+    pipes = new SerialPipes(pipeList);
 
     if (parsedArgs.containsKey("model")) {
       String modelFile = parsedArgs.get("model");
@@ -51,28 +84,30 @@ public class InflectionPredictor extends StatelessFF {
           System.exit(1);
         }
       } else {
-        // TODO: load the model
+        try {
+          loadClassifier(modelFile);
+        } catch (IOException e) {
+          // TODO Auto-generated catch block
+          e.printStackTrace();
+        } catch (ClassNotFoundException e) {
+          // TODO Auto-generated catch block
+          e.printStackTrace();
+        }
       }
     }
   }
   
+  /**
+   * Trains a maxent classifier from the provided training data, returning a Mallet model.
+   * 
+   * @param dataFile
+   * @return
+   * @throws FileNotFoundException
+   */
   public Classifier train(String dataFile) throws FileNotFoundException {
-    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
-    // I don't know if this is needed
-    pipeList.add(new Target2Label());
-    // Convert SVM-light format to sparse feature vector
-    pipeList.add(new SvmLight2FeatureVectorAndLabel());
-    // Validation
-//    pipeList.add(new PrintInputAndTarget());
-    
-    // name: english word
-    // data: features (FeatureVector)
-    // target: foreign inflection
-    // source: null
 
     // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
-    InstanceList instances = new InstanceList(new SerialPipes(pipeList));
+    InstanceList instances = new InstanceList(pipes);
     instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
         "(\\w+)\\s+(.*)",
         2, -1, 1));
@@ -82,22 +117,130 @@ public class InflectionPredictor extends StatelessFF {
     
     return classifier;
   }
-    
+
+  public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
+    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
+    classifier = (Classifier) ois.readObject();
+  }
+
+  public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
+    ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
+    oos.writeObject(classifier);
+    oos.close();
+  }
+  
+  public Classification predict(String outcome, String features) {
+    Instance instance = new Instance(features, null, null, null);
+    System.err.println("PREDICT outcome = " + (String) instance.getTarget());
+    System.err.println("PREDICT features = " + (String) instance.getData());
+    Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+   
+    return result;
+  }
+
+  /**
+   * Compute features. This works by walking over the target side phrase pieces, looking for every
+   * word with a single source-aligned word. We then throw the annotations from that source word
+   * into our prediction model to learn how much it likes the chosen word. Presumably the source-
+   * language annotations have contextual features, so this effectively chooses the words in context.
+   */
   @Override
   public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
       Sentence sentence, Accumulator acc) {
+        
+    Map<Integer, List<Integer>> points = rule.getAlignmentMap();
+    for (int t: points.keySet()) {
+      List<Integer> source_indices = points.get(t);
+      if (source_indices.size() != 1)
+        continue;
+      
+      String targetWord = Vocabulary.word(rule.getEnglish()[t]);
+      int s = i + source_indices.get(0);
+      Token sourceToken = sentence.getTokens().get(s);
+      String featureString = sourceToken.getAnnotationString().replace('|', ' ');
+      
+      Classification result = predict(targetWord, featureString);
+      if (result.bestLabelIsCorrect()) {
+        acc.add(String.format("%s_match", name), 1);
+      }
+    }
     
     return null;
   }
   
-  public static void main(String[] args) throws FileNotFoundException {
-    InflectionPredictor ip = new InflectionPredictor(null, args, null);
+  /**
+   * Returns an array parallel to the source words array indicating, for each index, the absolute
+   * position of that word into the source sentence. For example, for the rule with source side
+   * 
+   * [ 17, 142, -14, 9 ]
+   * 
+   * and source sentence
+   * 
+   * [ 17, 18, 142, 1, 1, 9, 8 ]
+   * 
+   * it will return
+   * 
+   * [ 0, 2, -14, 5 ]
+   * 
+   * which indicates that the first, second, and fourth words of the rule are anchored to the
+   * first, third, and sixth words of the input sentence. 
+   * 
+   * @param rule
+   * @param tailNodes
+   * @param start
+   * @return a list of alignment points anchored to the source sentence
+   */
+  public int[] anchorRuleSourceToSentence(Rule rule, List<HGNode> tailNodes, int start) {
+    int[] source = rule.getFrench();
+
+    // Map the source words in the rule to absolute positions in the sentence
+    int[] anchoredSource = source.clone();
     
-    String dataFile = "/Users/post/Desktop/amazon16/model";
-    if (args.length > 0)
-      dataFile = args[0];
+    int sourceIndex = start;
+    int tailNodeIndex = 0;
+    for (int i = 0; i < source.length; i++) {
+      if (source[i] < 0) { // nonterminal
+        anchoredSource[i] = source[i];
+        sourceIndex = tailNodes.get(tailNodeIndex).j;
+        tailNodeIndex++;
+      } else { // terminal
+        anchoredSource[i] = sourceIndex;
+        sourceIndex++;
+      }
+    }
     
-    ip.train(dataFile);
+    return anchoredSource;
   }
 
+  public static void main(String[] args) throws IOException, ClassNotFoundException {
+    InflectionPredictor ts = new InflectionPredictor(null, args, null);
+    
+    String modelFile = "model";
+
+    if (args.length > 0) {
+      String dataFile = args[0];
+
+      System.err.println("Training model from file " + dataFile);
+      ts.train(dataFile);
+    
+      if (args.length > 1)
+        modelFile = args[1];
+      
+      System.err.println("Writing model to file " + modelFile); 
+      ts.saveClassifier(modelFile);
+    } else {
+      System.err.println("Loading model from file " + modelFile);
+      ts.loadClassifier(modelFile);
+    }
+    
+    Scanner stdin = new Scanner(System.in);
+    while(stdin.hasNextLine()) {
+      String line = stdin.nextLine();
+      String[] tokens = line.split(" ", 2);
+      String outcome = tokens[0];
+      String features = tokens[1];
+      Classification result = ts.predict(outcome, features);
+      System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/47f1af58/src/joshua/decoder/segment_file/Token.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/segment_file/Token.java b/src/joshua/decoder/segment_file/Token.java
index 12e2b68..7969294 100644
--- a/src/joshua/decoder/segment_file/Token.java
+++ b/src/joshua/decoder/segment_file/Token.java
@@ -36,6 +36,7 @@ public class Token {
   private int tokenID;
 
   private HashMap<String,String> annotations = null;
+  private String annotationString;
 
   /**
    * Constructor : Creates a Token object from a raw word
@@ -69,9 +70,9 @@ public class Token {
     if (tag.find()) {
       // Annotation match found
       token = tag.group(1);
-      String tagStr = tag.group(2);
+      annotationString = tag.group(2);
 
-      for (String annotation: tagStr.split(";")) {
+      for (String annotation: annotationString.split(";")) {
         int where = annotation.indexOf("=");
         if (where != -1) {
           annotations.put(annotation.substring(0, where), annotation.substring(where + 1));
@@ -121,4 +122,11 @@ public class Token {
     
     return null;
   }
+  
+  /**
+   * Returns the raw annotation string
+   */
+  public String getAnnotationString() {
+    return annotationString;
+  }
 }
\ No newline at end of file