You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:51 UTC

[05/13] incubator-joshua git commit: Version with custom line processor; still not working (alphabet problems)

Version with custom line processor; still not working (alphabet problems)


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/d17d08af
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/d17d08af
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/d17d08af

Branch: refs/heads/morph
Commit: d17d08afb449ce86c5f0d9ca5b2ccdc5ea777fda
Parents: 0ef5d3e
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 14:11:37 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 14:11:37 2016 -0400

----------------------------------------------------------------------
 .../decoder/ff/morph/LexicalSharpener.java      | 192 ++++++++++++++++---
 1 file changed, 164 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/d17d08af/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index edf4390..b4b455d 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -1,5 +1,7 @@
 package joshua.decoder.ff.morph;
 
+import java.io.BufferedReader;
+
 /***
  * This feature function scores a rule application by predicting, for each target word aligned with
  * a source word, how likely the lexical translation is in context.
@@ -20,11 +22,12 @@ import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileNotFoundException;
 import java.io.FileOutputStream;
-import java.io.FileReader;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.io.StringReader;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Scanner;
@@ -32,9 +35,12 @@ import java.util.Scanner;
 import cc.mallet.classify.*;
 import cc.mallet.pipe.*;
 import cc.mallet.pipe.iterator.CsvIterator;
+import cc.mallet.types.Alphabet;
 import cc.mallet.types.Instance;
 import cc.mallet.types.InstanceList;
+import cc.mallet.types.LabelAlphabet;
 import joshua.corpus.Vocabulary;
+import joshua.decoder.Decoder;
 import joshua.decoder.JoshuaConfiguration;
 import joshua.decoder.chart_parser.SourcePath;
 import joshua.decoder.ff.FeatureVector;
@@ -44,21 +50,28 @@ import joshua.decoder.ff.tm.Rule;
 import joshua.decoder.hypergraph.HGNode;
 import joshua.decoder.segment_file.Sentence;
 import joshua.decoder.segment_file.Token;
+import joshua.util.io.LineReader;
 
 public class LexicalSharpener extends StatelessFF {
 
-  private Classifier classifier = null;
+  private HashMap<Integer,Object> classifiers = null;
   private SerialPipes pipes = null;
+  private InstanceList instances = null;
+  private LabelAlphabet labelAlphabet = null;
+  private Alphabet dataAlphabet = null;
   
   public LexicalSharpener(final FeatureVector weights, String[] args, JoshuaConfiguration config) {
     super(weights, "LexicalSharpener", args, config);
     
     ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
 
+    dataAlphabet = new Alphabet();
+    labelAlphabet = new LabelAlphabet();
+    
     // I don't know if this is needed
-    pipeList.add(new Target2Label());
-    // Convert SVM-light format to sparse feature vector
-    pipeList.add(new SvmLight2FeatureVectorAndLabel());
+//    pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+    // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
+    pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
     // Validation
 //    pipeList.add(new PrintInputAndTarget());
     
@@ -68,13 +81,15 @@ public class LexicalSharpener extends StatelessFF {
     // source: null
 
     pipes = new SerialPipes(pipeList);
+    instances = new InstanceList(dataAlphabet, labelAlphabet);
+    instances.setPipe(pipes);
 
     if (parsedArgs.containsKey("model")) {
       String modelFile = parsedArgs.get("model");
       if (! new File(modelFile).exists()) {
         if (parsedArgs.getOrDefault("training-data", null) != null) {
           try {
-            classifier = train(parsedArgs.get("training-data"));
+            train(parsedArgs.get("training-data"));
           } catch (FileNotFoundException e) {
             // TODO Auto-generated catch block
             e.printStackTrace();
@@ -85,7 +100,7 @@ public class LexicalSharpener extends StatelessFF {
         }
       } else {
         try {
-          loadClassifier(modelFile);
+          loadClassifiers(modelFile);
         } catch (IOException e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
@@ -104,39 +119,103 @@ public class LexicalSharpener extends StatelessFF {
    * @return
    * @throws FileNotFoundException
    */
-  public Classifier train(String dataFile) throws FileNotFoundException {
-
-    // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
-    InstanceList instances = new InstanceList(pipes);
-    instances.addThruPipe(new CsvIterator(new FileReader(dataFile),
-        "(\\w+)\\s+(.*)",
-        2, -1, 1));
-          
-    ClassifierTrainer trainer = new MaxEntTrainer();
-    Classifier classifier = trainer.train(instances);
+  public void train(String dataFile) throws FileNotFoundException {
+
+    classifiers = new HashMap<Integer, Object>();
     
-    return classifier;
+    Decoder.VERBOSE = 1;
+    LineReader lineReader = null;
+    try {
+      lineReader = new LineReader(dataFile, true);
+    } catch (IOException e) {
+      // TODO Auto-generated catch block
+      e.printStackTrace();
+    }
+
+    String lastOutcome = null;
+    String buffer = "";
+    int linesRead = 0;
+    for (String line : lineReader) {
+      String outcome = line.substring(0, line.indexOf(' '));
+      if (lastOutcome != null && ! outcome.equals(lastOutcome)) {
+        classifiers.put(Vocabulary.id(lastOutcome), buffer);
+//        System.err.println(String.format("WORD %s:\n%s\n", lastOutcome, buffer));
+        buffer = "";
+      }
+
+      buffer += line + "\n";
+      lastOutcome = outcome;
+      linesRead++;
+    }
+    classifiers.put(Vocabulary.id(lastOutcome),  buffer);
+    
+    System.err.println(String.format("Read %d lines from training file", linesRead));
   }
 
-  public void loadClassifier(String modelFile) throws ClassNotFoundException, IOException {
+  public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
     ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
-    classifier = (Classifier) ois.readObject();
+    classifiers = (HashMap<Integer,Object>) ois.readObject();
   }
 
-  public void saveClassifier(String modelFile) throws FileNotFoundException, IOException {
+  public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
     ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
-    oos.writeObject(classifier);
+    oos.writeObject(classifiers);
     oos.close();
   }
   
-  public Classification predict(String outcome, String features) {
-    Instance instance = new Instance(features, null, null, null);
-    System.err.println("PREDICT outcome = " + (String) instance.getTarget());
-    System.err.println("PREDICT features = " + (String) instance.getData());
+  /**
+   * Returns a Classification object for a list of features. Uses "which" to determine which classifier
+   * to use.
+   *   
+   * @param which the classifier to use
+   * @param features the set of features
+   * @return
+   */
+  public Classification predict(String which, String features) {
+    Instance instance = new Instance(features, which, null, null);
+//    System.err.println("PREDICT outcome = " + (String) instance.getTarget());
+//    System.err.println("PREDICT features = " + (String) instance.getData());
+    
+    Classifier classifier = getClassifier(which);
     Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
    
     return result;
   }
+  
+  public Classifier getClassifier(String which) {
+    int id = Vocabulary.id(which);
+    
+    if (classifiers.containsKey(id)) {
+      if (classifiers.get(id) instanceof String) {
+        StringReader reader = new StringReader((String)classifiers.get(id));
+
+        System.err.println("training new classifier from string for " + which);
+        
+        BufferedReader t = new BufferedReader(reader);
+        String line;
+        try {
+          while ((line = t.readLine()) != null) {
+            System.out.println("  LINE: " + line);
+          }
+        } catch (IOException e) {
+          // TODO Auto-generated catch block
+          e.printStackTrace();
+        }
+        
+        // Drop the first whitespace-delimited field of each line; the remainder goes to CustomLineProcessor
+        instances.addThruPipe(new CsvIterator(reader, "(\\w+)\\s+(.*)", 2, -1, -1));
+
+        ClassifierTrainer trainer = new MaxEntTrainer();
+        Classifier classifier = trainer.train(instances);
+        
+        classifiers.put(id, classifier);
+      }
+      
+      return (Classifier) classifiers.get(id);
+    }
+
+    return null;
+  }
 
   /**
    * Compute features. This works by walking over the target side phrase pieces, looking for every
@@ -211,6 +290,63 @@ public class LexicalSharpener extends StatelessFF {
     
     return anchoredSource;
   }
+  
+  public class CustomLineProcessor extends Pipe {
+
+    private static final long serialVersionUID = 1L;
+
+    public CustomLineProcessor() {
+      super (new Alphabet(), new LabelAlphabet());
+    }
+    
+    public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
+      super(data, label);
+    }
+    
+    @Override 
+    public Instance pipe(Instance carrier) {
+      // we expect the data for each instance to be
+      // a line from the SVMLight format text file    
+      String dataStr = (String)carrier.getData();
+
+      String[] tokens = dataStr.split("\\s+");
+
+      carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+
+      // the rest are feature-value pairs
+      ArrayList<Integer> indices = new ArrayList<Integer>();
+      ArrayList<Double> values = new ArrayList<Double>();
+      for (int termIndex = 1; termIndex < tokens.length; termIndex++) {
+        if (!tokens[termIndex].equals("")) {
+          String[] s = tokens[termIndex].split(":");
+          if (s.length != 2) {
+            throw new RuntimeException("invalid format: " + tokens[termIndex] + " (should be feature:value)");
+          }
+          String feature = s[0];
+          int index = getDataAlphabet().lookupIndex(feature, true);
+
+          // index may be -1 if growth of the
+          // data alphabet is stopped
+          if (index >= 0) {
+            indices.add(index);
+            values.add(Double.parseDouble(s[1]));
+          }
+        }
+      }
+
+      assert(indices.size() == values.size());
+      int[] indicesArr = new int[indices.size()];
+      double[] valuesArr = new double[values.size()];
+      for (int i = 0; i < indicesArr.length; i++) {
+        indicesArr[i] = indices.get(i);
+        valuesArr[i] = values.get(i);
+      }
+
+      cc.mallet.types.FeatureVector fv = new cc.mallet.types.FeatureVector(getDataAlphabet(), indicesArr, valuesArr);
+      carrier.setData(fv);
+      return carrier;
+    }
+  }
 
   public static void main(String[] args) throws IOException, ClassNotFoundException {
     LexicalSharpener ts = new LexicalSharpener(null, args, null);
@@ -227,10 +363,10 @@ public class LexicalSharpener extends StatelessFF {
         modelFile = args[1];
       
       System.err.println("Writing model to file " + modelFile); 
-      ts.saveClassifier(modelFile);
+      ts.saveClassifiers(modelFile);
     } else {
       System.err.println("Loading model from file " + modelFile);
-      ts.loadClassifier(modelFile);
+      ts.loadClassifiers(modelFile);
     }
     
     Scanner stdin = new Scanner(System.in);