You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@joshua.apache.org by mj...@apache.org on 2016/04/22 06:17:52 UTC

[06/13] incubator-joshua git commit: Will now train a single model, barfs with Alphabet issues on second, ARGH

Training now works for a single model, but fails with Alphabet issues when training a second one.


Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/a055c3f7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/a055c3f7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/a055c3f7

Branch: refs/heads/morph
Commit: a055c3f7edb84234ecfbdf56d4be92a60fc5fc0e
Parents: d17d08a
Author: Matt Post <po...@cs.jhu.edu>
Authored: Thu Apr 21 15:06:35 2016 -0400
Committer: Matt Post <po...@cs.jhu.edu>
Committed: Thu Apr 21 15:06:57 2016 -0400

----------------------------------------------------------------------
 .../decoder/ff/morph/LexicalSharpener.java      | 77 +++++++++-----------
 1 file changed, 35 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/a055c3f7/src/joshua/decoder/ff/morph/LexicalSharpener.java
----------------------------------------------------------------------
diff --git a/src/joshua/decoder/ff/morph/LexicalSharpener.java b/src/joshua/decoder/ff/morph/LexicalSharpener.java
index b4b455d..fab71d9 100644
--- a/src/joshua/decoder/ff/morph/LexicalSharpener.java
+++ b/src/joshua/decoder/ff/morph/LexicalSharpener.java
@@ -69,7 +69,7 @@ public class LexicalSharpener extends StatelessFF {
     labelAlphabet = new LabelAlphabet();
     
     // I don't know if this is needed
-//    pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
+    pipeList.add(new Target2Label(dataAlphabet, labelAlphabet));
     // Convert custom lines to Instance objects (svmLight2FeatureVectorAndLabel not versatile enough)
     pipeList.add(new CustomLineProcessor(dataAlphabet, labelAlphabet));
     // Validation
@@ -155,6 +155,12 @@ public class LexicalSharpener extends StatelessFF {
   public void loadClassifiers(String modelFile) throws ClassNotFoundException, IOException {
     ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
     classifiers = (HashMap<Integer,Object>) ois.readObject();
+    ois.close();
+    
+    System.err.println(String.format("Loaded model with %d keys", classifiers.keySet().size()));
+    for (int key: classifiers.keySet()) {
+      System.err.println("  " + key);
+    }
   }
 
   public void saveClassifiers(String modelFile) throws FileNotFoundException, IOException {
@@ -172,14 +178,17 @@ public class LexicalSharpener extends StatelessFF {
    * @return
    */
   public Classification predict(String which, String features) {
-    Instance instance = new Instance(features, which, null, null);
+    Instance instance = new Instance(features, which, which, which);
 //    System.err.println("PREDICT outcome = " + (String) instance.getTarget());
 //    System.err.println("PREDICT features = " + (String) instance.getData());
     
     Classifier classifier = getClassifier(which);
-    Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
-   
-    return result;
+    if (classifier != null) {
+      Classification result = (Classification) classifier.classify(pipes.instanceFrom(instance));
+      return result;
+    }
+    
+    return null;
   }
   
   public Classifier getClassifier(String which) {
@@ -190,20 +199,10 @@ public class LexicalSharpener extends StatelessFF {
         StringReader reader = new StringReader((String)classifiers.get(id));
 
         System.err.println("training new classifier from string for " + which);
+        System.err.println(String.format("training string is: '%s'", (String)classifiers.get(id)));
         
-        BufferedReader t = new BufferedReader(reader);
-        String line;
-        try {
-          while ((line = t.readLine()) != null) {
-            System.out.println("  LINE: " + line);
-          }
-        } catch (IOException e) {
-          // TODO Auto-generated catch block
-          e.printStackTrace();
-        }
-        
-        // Remove the first field (Mallet's "name" field), leave the rest for SVM-light conversion
-        instances.addThruPipe(new CsvIterator(reader, "(\\w+)\\s+(.*)", 2, -1, -1));
+        // Constructs an instance with everything shoved into the data field
+        instances.addThruPipe(new CsvIterator(reader, "(\\S+)\\s+(\\S+)\\s+(.*)", 3, 2, 1));
 
         ClassifierTrainer trainer = new MaxEntTrainer();
         Classifier classifier = trainer.train(instances);
@@ -295,10 +294,6 @@ public class LexicalSharpener extends StatelessFF {
 
     private static final long serialVersionUID = 1L;
 
-    public CustomLineProcessor() {
-      super (new Alphabet(), new LabelAlphabet());
-    }
-    
     public CustomLineProcessor(Alphabet data, LabelAlphabet label) {
       super(data, label);
     }
@@ -311,12 +306,12 @@ public class LexicalSharpener extends StatelessFF {
 
       String[] tokens = dataStr.split("\\s+");
 
-      carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
+//      carrier.setTarget(((LabelAlphabet)getTargetAlphabet()).lookupLabel(tokens[1], true));
 
       // the rest are feature-value pairs
       ArrayList<Integer> indices = new ArrayList<Integer>();
       ArrayList<Double> values = new ArrayList<Double>();
-      for (int termIndex = 1; termIndex < tokens.length; termIndex++) {
+      for (int termIndex = 0; termIndex < tokens.length; termIndex++) {
         if (!tokens[termIndex].equals("")) {
           String[] s = tokens[termIndex].split(":");
           if (s.length != 2) {
@@ -324,13 +319,8 @@ public class LexicalSharpener extends StatelessFF {
           }
           String feature = s[0];
           int index = getDataAlphabet().lookupIndex(feature, true);
-
-          // index may be -1 if growth of the
-          // data alphabet is stopped
-          if (index >= 0) {
-            indices.add(index);
-            values.add(Double.parseDouble(s[1]));
-          }
+          indices.add(index);
+          values.add(Double.parseDouble(s[1]));
         }
       }
 
@@ -359,24 +349,27 @@ public class LexicalSharpener extends StatelessFF {
       System.err.println("Training model from file " + dataFile);
       ts.train(dataFile);
     
-      if (args.length > 1)
-        modelFile = args[1];
-      
-      System.err.println("Writing model to file " + modelFile); 
-      ts.saveClassifiers(modelFile);
-    } else {
-      System.err.println("Loading model from file " + modelFile);
-      ts.loadClassifiers(modelFile);
+//      if (args.length > 1)
+//        modelFile = args[1];
+//      
+//      System.err.println("Writing model to file " + modelFile); 
+//      ts.saveClassifiers(modelFile);
+//    } else {
+//      System.err.println("Loading model from file " + modelFile);
+//      ts.loadClassifiers(modelFile);
     }
     
     Scanner stdin = new Scanner(System.in);
     while(stdin.hasNextLine()) {
       String line = stdin.nextLine();
       String[] tokens = line.split(" ", 2);
-      String outcome = tokens[0];
+      String sourceWord = tokens[0];
       String features = tokens[1];
-      Classification result = ts.predict(outcome, features);
-      System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+      Classification result = ts.predict(sourceWord, features);
+      if (result != null)
+        System.out.println(String.format("%s %f", result.getLabelVector().getBestLabel(), result.getLabelVector().getBestValue()));
+      else 
+        System.out.println("i got nothing");
     }
   }
 }