You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:55 UTC

[09/13] incubator-joshua git commit: fiddling

fiddling


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/f0728fac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/f0728fac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/f0728fac

Branch: refs/heads/morph
Commit: f0728fac70bbacc7b33937f4915a091e4cfe4e36
Parents: b121fc2
Author: Matt Post <po...@cs.jhu.edu>
Authored: Fri Apr 22 00:11:54 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Fri Apr 22 00:11:54 2016 -0400

----------------------------------------------------------------------
 .../decoder/ff/morph/LexicalSharpener.java      | 159 +++++++++----------
 1 file changed, 79 insertions(+), 80 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/f0728fac/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index 701a72f..7982db4 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -1,7 +1,5 @@
 package joshua.decoder.ff.morph;
 
-import java.io.BufferedReader;
-
 /***
  * This feature function scores a rule application by predicting, for each target word aligned with
  * a source word, how likely the lexical translation is in context.
@@ -59,34 +57,56 @@ public class LexicalSharpener extends StatelessFF {
   public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
     super(weights, "LexicalSharpener", args, config);
 
-    if (parsedArgs.containsKey("model")) {
-      String modelFile = parsedArgs.get("model");
-      if (! new File(modelFile).exists()) {
-        if (parsedArgs.getOrDefault("training-data", null) != null) {
-          try {
-            trainAll(parsedArgs.get("training-data"));
-          } catch (FileNotFoundException e) {
-            // TODO Auto-generated catch block
-            e.printStackTrace();
-          }
-        } else {
-          System.err.println("* FATAL: no model and no training data.");
-          System.exit(1);
-        }
-      } else {
-        try {
-          loadClassifiers(modelFile);
-        } catch (IOException e) {
-          // TODO Auto-generated catch block
-          e.printStackTrace();
-        } catch (ClassNotFoundException e) {
-          // TODO Auto-generated catch block
-          e.printStackTrace();
-        }
+    if (parsedArgs.getOrDefault("training-data", null) != null) {
+      try {
+        trainAll(parsedArgs.get("training-data"));
+      } catch (FileNotFoundException e) {
+        System.err.println(String.format("* FATAL[LexicalSharpener]: can't load %s", parsedArgs.get("training-data")));
+        System.exit(1);
       }
     }
   }
   
+  /**
+   * Trains a maxent classifier from the provided training data, returning a Mallet model.
+   * 
+   * @param dataFile
+   * @return
+   * @throws FileNotFoundException
+   */
+  public void trainAll(String dataFile) throws FileNotFoundException {
+  
+    classifiers = new HashMap<Integer, Predictor>();
+
+    Decoder.LOG(1, "Reading " + dataFile);
+    LineReader lineReader = null;
+    try {
+      lineReader = new LineReader(dataFile, true);
+    } catch (IOException e) {
+      // TODO Auto-generated catch block
+      e.printStackTrace();
+    }
+  
+    String lastSourceWord = null;
+    String examples = "";
+    int linesRead = 0;
+    for (String line : lineReader) {
+      String sourceWord = line.substring(0, line.indexOf(' '));
+      if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+        classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+        //        System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+        examples = "";
+      }
+  
+      examples += line + "\n";
+      lastSourceWord = sourceWord;
+      linesRead++;
+    }
+    classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+  
+    System.err.println(String.format("Read %d lines from training file", linesRead));
+  }
+
   public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
     ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
     classifiers = (HashMap<Integer,Predictor>) ois.readObject();
@@ -113,6 +133,8 @@ public class LexicalSharpener extends StatelessFF {
   @Override
   public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
       Sentence sentence, Accumulator acc) {
+    
+    System.err.println(String.format("RULE: %s",  rule));
         
     Map<Integer, List<Integer>> points = rule.getAlignmentMap();
     for (int t: points.keySet()) {
@@ -120,12 +142,14 @@ public class LexicalSharpener extends StatelessFF {
       if (source_indices.size() != 1)
         continue;
       
-      String targetWord = Vocabulary.word(rule.getEnglish()[t]);
+      int targetID = rule.getEnglish()[t];
+      String targetWord = Vocabulary.word(targetID);
       int s = i + source_indices.get(0);
       Token sourceToken = sentence.getTokens().get(s);
       String featureString = sourceToken.getAnnotationString().replace('|', ' ');
       
-      Classification result = predict(sourceToken.getWord(), featureString);
+      Classification result = predict(sourceToken.getWord(), targetID, featureString);
+      System.out.println("RESULT: " + result.getLabeling());
       if (result.bestLabelIsCorrect()) {
         acc.add(String.format("%s_match", name), 1);
       }
@@ -134,12 +158,12 @@ public class LexicalSharpener extends StatelessFF {
     return null;
   }
   
-  public Classification predict(int id, String featureString) {
-    String word = Vocabulary.word(id);
-    if (classifiers.containsKey(id)) {
-      Predictor predictor = classifiers.get(id);
+  public Classification predict(int sourceID, int targetID, String featureString) {
+    String word = Vocabulary.word(sourceID);
+    if (classifiers.containsKey(sourceID)) {
+      Predictor predictor = classifiers.get(sourceID);
       if (predictor != null)
-        return predictor.predict(featureString);
+        return predictor.predict(Vocabulary.word(targetID), featureString);
     }
 
     return null;
@@ -226,10 +250,10 @@ public class LexicalSharpener extends StatelessFF {
        * @param features the set of features
        * @return
        */
-    public Classification predict(String features) {
-      Instance instance = new Instance(features, null, null, null);
-      //    System.err.println("PREDICT sourceWord = " + (String) instance.getTarget());
-      //    System.err.println("PREDICT features = " + (String) instance.getData());
+    public Classification predict(String outcome, String features) {
+      Instance instance = new Instance(features, outcome, null, null);
+      System.err.println("PREDICT targetWord = " + (String) instance.getTarget());
+      System.err.println("PREDICT features = " + (String) instance.getData());
 
       if (classifier == null)
         train();
@@ -239,8 +263,8 @@ public class LexicalSharpener extends StatelessFF {
     }
 
     public void train() {
-      System.err.println(String.format("Word %s: training model", sourceWord));
-      System.err.println(String.format("  Examples: %s", examples));
+//      System.err.println(String.format("Word %s: training model", sourceWord));
+//      System.err.println(String.format("  Examples: %s", examples));
       
       StringReader reader = new StringReader(examples);
 
@@ -249,47 +273,21 @@ public class LexicalSharpener extends StatelessFF {
 
       ClassifierTrainer trainer = new MaxEntTrainer();
       classifier = trainer.train(instances);
-    }
-  }
-  
-  /**
-   * Trains a maxent classifier from the provided training data, returning a Mallet model.
-   * 
-   * @param dataFile
-   * @return
-   * @throws FileNotFoundException
-   */
-  public void trainAll(String dataFile) throws FileNotFoundException {
-
-    classifiers = new HashMap<Integer, Predictor>();
-
-    Decoder.VERBOSE = 1;
-    LineReader lineReader = null;
-    try {
-      lineReader = new LineReader(dataFile, true);
-    } catch (IOException e) {
-      // TODO Auto-generated catch block
-      e.printStackTrace();
+      
+      System.err.println(String.format("Trained a model for %s with %d outcomes", 
+          sourceWord, pipes.getTargetAlphabet().size()));
     }
 
-    String lastSourceWord = null;
-    String examples = "";
-    int linesRead = 0;
-    for (String line : lineReader) {
-      String sourceWord = line.substring(0, line.indexOf(' '));
-      if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
-        classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
-        //        System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
-        examples = "";
-      }
-
-      examples += line + "\n";
-      lastSourceWord = sourceWord;
-      linesRead++;
+    /**
+     * Returns the number of distinct outcomes. Requires the model to have been trained!
+     * 
+     * @return
+     */
+    public int getNumOutcomes() {
+      if (classifier == null)
+        train();
+      return pipes.getTargetAlphabet().size();
     }
-    classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
-
-    System.err.println(String.format("Read %d lines from training file", linesRead));
   }
   
   public static void example(String[] args) throws IOException, ClassNotFoundException {
@@ -345,10 +343,11 @@ public class LexicalSharpener extends StatelessFF {
     Scanner stdin = new Scanner(System.in);
     while(stdin.hasNextLine()) {
       String line = stdin.nextLine();
-      String[] tokens = line.split(" ", 2);
+      String[] tokens = line.split(" ", 3);
       String sourceWord = tokens[0];
-      String features = tokens[1];
-      Classification result = ts.predict(Vocabulary.id(sourceWord), features);
+      String targetWord = tokens[1];
+      String features = tokens[2];
+      Classification result = ts.predict(Vocabulary.id(sourceWord), Vocabulary.id(targetWord), features);
       if (result != null)
         System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
       else