You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:54 UTC

[08/13] incubator-joshua git commit: Rewrote in less efficient way, now works!

Rewrote in less efficient way, now works!


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/b121fc20
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/b121fc20
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/b121fc20

Branch: refs/heads/morph
Commit: b121fc20c22dd6752ccf3e728ed8c23ca73eb9aa
Parents: 23306b4
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 17:59:50 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 17:59:50 2016 -0400

----------------------------------------------------------------------
 .../decoder/ff/morph/LexicalSharpener.java      | 267 ++++++++-----------
 1 file changed, 110 insertions(+), 157 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/b121fc20/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index a66f4a1..701a72f 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -55,42 +55,16 @@ import joshua.util.io.LineReader;
 
 public class LexicalSharpener extends StatelessFF {
 
-  private HashMap<Integer,Object> classifiers = null;
-  private SerialPipes pipes = null;
-  private InstanceList instances = null;
-  private LabelAlphabet labelAlphabet = null;
-  private Alphabet dataAlphabet = null;
-  
+  private HashMap<Integer,Predictor> classifiers = null;
   public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
     super(weights, "LexicalSharpener", args, config);
-    
-    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
-
-    dataAlphabet = new Alphabet();
-    labelAlphabet = new LabelAlphabet();
-    
-    // I don't know if this is needed
-    pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
-    // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
-    pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
-    // Validation
-//    pipeList.add(new PrintInputAndTarget());
-    
-    // name: english word
-    // data: features (FeatureVector)
-    // target: foreign inflection
-    // source: null
-
-    pipes = new SerialPipes(pipeList);
-    instances = new InstanceList(dataAlphabet, labelAlphabet);
-    instances.setPipe(pipes);
 
     if (parsedArgs.containsKey("model")) {
       String modelFile = parsedArgs.get("model");
       if (! new File(modelFile).exists()) {
         if (parsedArgs.getOrDefault("training-data", null) != null) {
           try {
-            train(parsedArgs.get("training-data"));
+            trainAll(parsedArgs.get("training-data"));
           } catch (FileNotFoundException e) {
             // TODO Auto-generated catch block
             e.printStackTrace();
@@ -113,49 +87,9 @@ public class LexicalSharpener extends StatelessFF {
     }
   }
   
-  /**
-   * Trains a maxent classifier from the provided training data, returning a Mallet model.
-   * 
-   * @param dataFile
-   * @return
-   * @throws FileNotFoundException
-   */
-  public void train(String dataFile) throws FileNotFoundException {
-
-    classifiers = new HashMap<Integer, Object>();
-    
-    Decoder.VERBOSE = 1;
-    LineReader lineReader = null;
-    try {
-      lineReader = new LineReader(dataFile, true);
-    } catch (IOException e) {
-      // TODO Auto-generated catch block
-      e.printStackTrace();
-    }
-
-    String lastOutcome = null;
-    String buffer = "";
-    int linesRead = 0;
-    for (String line : lineReader) {
-      String outcome = line.substring(0, line.indexOf(' '));
-      if (lastOutcome != null && ! outcome.equals(lastOutcome)) {
-        classifiers.put(Vocabulary.id(lastOutcome), buffer);
-//        System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
-        buffer = "";
-      }
-
-      buffer += line + "\n";
-      lastOutcome = outcome;
-      linesRead++;
-    }
-    classifiers.put(Vocabulary.id(lastOutcome),  buffer);
-    
-    System.err.println(String.format("Read %d lines from training file", linesRead));
-  }
-
   public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
     ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
-    classifiers = (HashMap<Integer,Object>) ois.readObject();
+    classifiers = (HashMap<Integer,Predictor>) ois.readObject();
     ois.close();
     
     System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
@@ -171,53 +105,6 @@ public class LexicalSharpener extends StatelessFF {
   }
   
   /**
-   * Returns a Classification object a list of features. Uses "which" to determine which classifier
-   * to use.
-   *   
-   * @param which the classifier to use
-   * @param features the set of features
-   * @return
-   */
-  public Classification predict(String which, String features) {
-    Instance instance = new Instance(features, which, which, which);
-//    System.err.println("PREDICT outcome = " + (String) instance.getTarget());
-//    System.err.println("PREDICT features = " + (String) instance.getData());
-    
-    Classifier classifier = getClassifier(which);
-    if (classifier != null) {
-      Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
-      return result;
-    }
-    
-    return null;
-  }
-  
-  public Classifier getClassifier(String which) {
-    int id = Vocabulary.id(which);
-    
-    if (classifiers.containsKey(id)) {
-      if (classifiers.get(id) instanceof String) {
-        StringReader reader = new StringReader((String)classifiers.get(id));
-
-        System.err.println("training new classifier from string for " + which);
-        System.err.println(String.format("training string is: '%s'", (String)classifiers.get(id)));
-        
-        // Constructs an instance with everything shoved into the data field
-        instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
-
-        ClassifierTrainer trainer = new MaxEntTrainer();
-        Classifier classifier = trainer.train(instances);
-        
-        classifiers.put(id, classifier);
-      }
-      
-      return (Classifier) classifiers.get(id);
-    }
-
-    return null;
-  }
-
-  /**
    * Compute features. This works by walking over the target side phrase pieces, looking for every
    * word with a single source-aligned word. We then throw the annotations from that source word
    * into our prediction model to learn how much it likes the chosen word. Presumably the source-
@@ -238,7 +125,7 @@ public class LexicalSharpener extends StatelessFF {
       Token sourceToken = sentence.getTokens().get(s);
       String featureString = sourceToken.getAnnotationString().replace('|', ' ');
       
-      Classification result = predict(targetWord, featureString);
+      Classification result = predict(sourceToken.getWord(), featureString);
       if (result.bestLabelIsCorrect()) {
         acc.add(String.format("%s_match", name), 1);
       }
@@ -247,6 +134,17 @@ public class LexicalSharpener extends StatelessFF {
     return null;
   }
   
+  public Classification predict(int id, String featureString) {
+    String word = Vocabulary.word(id);
+    if (classifiers.containsKey(id)) {
+      Predictor predictor = classifiers.get(id);
+      if (predictor != null)
+        return predictor.predict(featureString);
+    }
+
+    return null;
+  }
+  
   /**
    * Returns an array parallel to the source words array indicating, for each index, the absolute
    * position of that word into the source sentence. For example, for the rule with source side
@@ -291,54 +189,109 @@ public class LexicalSharpener extends StatelessFF {
     return anchoredSource;
   }
   
-  public static class CustomLineProcessor extends Pipe {
+  public class Predictor {
+    
+    private SerialPipes pipes = null;
+    private InstanceList instances = null;
+    private String sourceWord = null;
+    private String examples = null;
+    private Classifier classifier = null;
+    
+    public Predictor(String word, String examples) {
+      this.sourceWord = word;
+      this.examples = examples;
+      ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
+
+      // I don't know if this is needed
+      pipeList.add(new Target2Label());
+      // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+      pipeList.add(new SvmLight2FeatureVectorAndLabel());
+      // Validation
+//      pipeList.add(new PrintInputAndTarget());
+      
+      // name: english word
+      // data: features (FeatureVector)
+      // target: foreign inflection
+      // source: null
 
-    private static final long serialVersionUID = 1L;
+      pipes = new SerialPipes(pipeList);
+      instances = new InstanceList(pipes);
+    }
+
+    /**
+       * Returns a Classification object given a list of features, using this
+       * word's classifier. The classifier is trained lazily on the stored
+       * examples the first time this method is called.
+       *   
+       * @param features the set of features as a space-separated string
+       * @return the classification result
+       */
+    public Classification predict(String features) {
+      Instance instance = new Instance(features, null, null, null);
+      //    System.err.println("PREDICT sourceWord = " + (String) instance.getTarget());
+      //    System.err.println("PREDICT features = " + (String) instance.getData());
+
+      if (classifier == null)
+        train();
 
-    public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
-      super(data, label);
+      Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+      return result;
     }
-    
-    @Override 
-    public Instance pipe(Instance carrier) {
-      // we expect the data for each instance to be
-      // a line from the SVMLight format text file    
-      String dataStr = (String)carrier.getData();
 
-      String[] tokens = dataStr.split("\\s+");
+    public void train() {
+      System.err.println(String.format("Word %s: training model", sourceWord));
+      System.err.println(String.format("  Examples: %s", examples));
+      
+      StringReader reader = new StringReader(examples);
 
-//      carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+      // Constructs an instance with everything shoved into the data field
+      instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(.*)", 2, -1, 1));
 
-      // the rest are feature-value pairs
-      ArrayList<Integer> indices = new ArrayList<Integer>();
-      ArrayList<Double> values = new ArrayList<Double>();
-      for (int termIndex = 0; termIndex < tokens.length; termIndex++) {
-        if (!tokens[termIndex].equals("")) {
-          String[] s = tokens[termIndex].split(":");
-          if (s.length != 2) {
-            throw new RuntimeException("invalid format: " + tokens[termIndex] + " (should be feature:value)");
-          }
-          String feature = s[0];
-          int index = getDataAlphabet().lookupIndex(feature, true);
-          indices.add(index);
-          values.add(Double.parseDouble(s[1]));
-        }
-      }
+      ClassifierTrainer trainer = new MaxEntTrainer();
+      classifier = trainer.train(instances);
+    }
+  }
+  
+  /**
+   * Builds one lazily-trained maxent Predictor per source word from the
+   * provided training data, keyed by vocabulary ID.
+   * 
+   * @param dataFile path to the training data file
+   * @throws FileNotFoundException if the training file cannot be opened
+   */
+  public void trainAll(String dataFile) throws FileNotFoundException {
+
+    classifiers = new HashMap<Integer, Predictor>();
 
-      assert(indices.size() == values.size());
-      int[] indicesArr = new int[indices.size()];
-      double[] valuesArr = new double[values.size()];
-      for (int i = 0; i < indicesArr.length; i++) {
-        indicesArr[i] = indices.get(i);
-        valuesArr[i] = values.get(i);
+    Decoder.VERBOSE = 1;
+    LineReader lineReader = null;
+    try {
+      lineReader = new LineReader(dataFile, true);
+    } catch (IOException e) {
+      // TODO Auto-generated catch block
+      e.printStackTrace();
+    }
+
+    String lastSourceWord = null;
+    String examples = "";
+    int linesRead = 0;
+    for (String line : lineReader) {
+      String sourceWord = line.substring(0, line.indexOf(' '));
+      if (lastSourceWord != null && ! sourceWord.equals(lastSourceWord)) {
+        classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
+        //        System.err.println(String.format("WORD %s:\n%s\n", lastSourceWord, examples));
+        examples = "";
       }
 
-      cc.mallet.types.FeatureVector fv = new cc.mallet.types.FeatureVector(getDataAlphabet(), indicesArr, valuesArr);
-      carrier.setData(fv);
-      return carrier;
+      examples += line + "\n";
+      lastSourceWord = sourceWord;
+      linesRead++;
     }
-  }
+    classifiers.put(Vocabulary.id(lastSourceWord), new Predictor(lastSourceWord, examples));
 
+    System.err.println(String.format("Read %d lines from training file", linesRead));
+  }
+  
   public static void example(String[] args) throws IOException, ClassNotFoundException {
 
     ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
@@ -348,7 +301,7 @@ public class LexicalSharpener extends StatelessFF {
     
     pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
     // Basically, SvmLight but with a custom (fixed) alphabet)
-    pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
+    pipeList.add(new SvmLight2FeatureVectorAndLabel());
 
     FileReader reader1 = new FileReader("data.1");
     FileReader reader2 = new FileReader("data.2");
@@ -377,7 +330,7 @@ public class LexicalSharpener extends StatelessFF {
       String dataFile = args[0];
 
       System.err.println("Training model from file " + dataFile);
-      ts.train(dataFile);
+      ts.trainAll(dataFile);
     
 //      if (args.length > 1)
 //        modelFile = args[1];
@@ -395,7 +348,7 @@ public class LexicalSharpener extends StatelessFF {
       String[] tokens = line.split(" ", 2);
       String sourceWord = tokens[0];
       String features = tokens[1];
-      Classification result = ts.predict(sourceWord, features);
+      Classification result = ts.predict(Vocabulary.id(sourceWord), features);
       if (result != null)
         System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
       else