You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by co...@apache.org on 2012/05/08 22:43:35 UTC

svn commit: r1335756 - in /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools: cmdline/postag/POSTaggerCrossValidatorTool.java cmdline/postag/POSTaggerTrainerTool.java postag/POSTaggerCrossValidator.java postag/POSTaggerFactory.java

Author: colen
Date: Tue May  8 20:43:34 2012
New Revision: 1335756

URL: http://svn.apache.org/viewvc?rev=1335756&view=rev
Log:
OPENNLP-508 Added the mutable dictionary capability to the POSTagger trainer and cross-validator tools

Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java Tue May  8 20:43:34 2012
@@ -92,8 +92,8 @@ public final class POSTaggerCrossValidat
       }
 
       validator = new POSTaggerCrossValidator(factory.getLang(), mlParams,
-          tagdict, params.getNgram(), params.getFactory(),
-          missclassifiedListener, reportListener);
+          tagdict, params.getNgram(), params.getTagDictCutoff(),
+          params.getFactory(), missclassifiedListener, reportListener);
       
       validator.evaluate(sampleStream, params.getFolds());
     } catch (IOException e) {

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java Tue May  8 20:43:34 2012
@@ -28,6 +28,7 @@ import opennlp.tools.cmdline.TerminateTo
 import opennlp.tools.cmdline.params.TrainingToolParams;
 import opennlp.tools.cmdline.postag.POSTaggerTrainerTool.TrainerToolParams;
 import opennlp.tools.dictionary.Dictionary;
+import opennlp.tools.postag.MutableTagDictionary;
 import opennlp.tools.postag.POSDictionary;
 import opennlp.tools.postag.POSModel;
 import opennlp.tools.postag.POSSample;
@@ -103,6 +104,27 @@ public final class POSTaggerTrainerTool
       throw new TerminateToolException(-1, e.getMessage());
     }
     
+    if (params.getTagDictCutoff() != null) {
+      try {
+        POSDictionary dict = postaggerFactory.getPOSDictionary();
+        if (dict == null) {
+          dict = postaggerFactory.createEmptyPOSDictionary();
+        }
+        if (dict instanceof MutableTagDictionary) {
+          POSTaggerME.populatePOSDictionary(sampleStream, dict,
+              params.getTagDictCutoff());
+        } else {
+          throw new IllegalArgumentException(
+              "Can't extend a POSDictionary that does not implement MutableTagDictionary.");
+        }
+        sampleStream.reset();
+      } catch (IOException e) {
+        throw new TerminateToolException(-1,
+            "IO error while creating/extending POS Dictionary: "
+                + e.getMessage());
+      }
+    }
+
     POSModel model;
     try {
       model = opennlp.tools.postag.POSTaggerME.train(factory.getLang(),

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java Tue May  8 20:43:34 2012
@@ -44,6 +44,8 @@ public class POSTaggerCrossValidator {
   private String factoryClassName;
   /* user can also send a ready to use factory */
   private POSTaggerFactory factory;
+
+  private Integer tagdicCutoff = null;
   
   /**
    * Creates a {@link POSTaggerCrossValidator} that builds a ngram dictionary
@@ -52,7 +54,7 @@ public class POSTaggerCrossValidator {
    */
   public POSTaggerCrossValidator(String languageCode,
       TrainingParameters trainParam, POSDictionary tagDictionary,
-      Integer ngramCutoff, String factoryClass,
+      Integer ngramCutoff, Integer tagdicCutoff, String factoryClass,
       POSTaggerEvaluationMonitor... listeners) {
     this.languageCode = languageCode;
     this.params = trainParam;
@@ -61,6 +63,7 @@ public class POSTaggerCrossValidator {
     this.listeners = listeners;
     this.factoryClassName = factoryClass;
     this.ngramDictionary = null;
+    this.tagdicCutoff = tagdicCutoff;
   }
 
   /**
@@ -69,7 +72,7 @@ public class POSTaggerCrossValidator {
    */
   public POSTaggerCrossValidator(String languageCode,
       TrainingParameters trainParam, POSTaggerFactory factory,
-      POSTaggerEvaluationMonitor... listeners) {
+      Integer tagdicCutoff, POSTaggerEvaluationMonitor... listeners) {
     this.languageCode = languageCode;
     this.params = trainParam;
     this.listeners = listeners;
@@ -77,9 +80,27 @@ public class POSTaggerCrossValidator {
     this.tagDictionary = null;
     this.ngramDictionary = null;
     this.ngramCutoff = null;
+    this.tagdicCutoff = tagdicCutoff;
   }
   
   /**
+   * Creates a {@link POSTaggerCrossValidator} using the given
+   * {@link POSTaggerFactory}.
+   */
+  public POSTaggerCrossValidator(String languageCode,
+      TrainingParameters trainParam, Integer posdicCutoff, POSTaggerFactory factory,
+      POSTaggerEvaluationMonitor... listeners) {
+    this.languageCode = languageCode;
+    this.params = trainParam;
+    this.listeners = listeners;
+    this.factory = factory;
+    this.tagDictionary = null;
+    this.ngramDictionary = null;
+    this.ngramCutoff = null;
+    this.tagdicCutoff = posdicCutoff;
+  }
+
+  /**
    * @deprecated use
    *             {@link #POSTaggerCrossValidator(String, TrainingParameters, POSTaggerFactory, POSTaggerEvaluationMonitor...)}
    *             instead and pass in a {@link TrainingParameters} object and a
@@ -87,7 +108,7 @@ public class POSTaggerCrossValidator {
    */
   public POSTaggerCrossValidator(String languageCode, ModelType modelType, POSDictionary tagDictionary,
       Dictionary ngramDictionary, int cutoff, int iterations) {
-    this(languageCode, create(modelType, cutoff, iterations), create(ngramDictionary, tagDictionary));
+    this(languageCode, create(modelType, cutoff, iterations), null, create(ngramDictionary, tagDictionary));
   }
   
   /**
@@ -98,7 +119,7 @@ public class POSTaggerCrossValidator {
    */
   public POSTaggerCrossValidator(String languageCode, ModelType modelType, POSDictionary tagDictionary,
       Dictionary ngramDictionary) {
-    this(languageCode, create(modelType, 5, 100), create(ngramDictionary, tagDictionary));
+    this(languageCode, create(modelType, 5, 100), null, create(ngramDictionary, tagDictionary));
   }
 
   /**
@@ -109,7 +130,7 @@ public class POSTaggerCrossValidator {
   public POSTaggerCrossValidator(String languageCode,
       TrainingParameters trainParam, POSDictionary tagDictionary,
       POSTaggerEvaluationMonitor... listeners) {
-    this(languageCode, trainParam, create(null, tagDictionary), listeners);
+    this(languageCode, trainParam, null, create(null, tagDictionary), listeners);
   }
   
   /**
@@ -121,7 +142,7 @@ public class POSTaggerCrossValidator {
   public POSTaggerCrossValidator(String languageCode,
       TrainingParameters trainParam, POSDictionary tagDictionary,
       Integer ngramCutoff, POSTaggerEvaluationMonitor... listeners) {
-    this(languageCode, trainParam, tagDictionary, ngramCutoff,
+    this(languageCode, trainParam, tagDictionary, ngramCutoff, null,
         POSTaggerFactory.class.getCanonicalName(), listeners);
   }
 
@@ -133,7 +154,7 @@ public class POSTaggerCrossValidator {
   public POSTaggerCrossValidator(String languageCode,
       TrainingParameters trainParam, POSDictionary tagDictionary,
       Dictionary ngramDictionary, POSTaggerEvaluationMonitor... listeners) {
-    this(languageCode, trainParam, create(ngramDictionary, tagDictionary), listeners);
+    this(languageCode, trainParam, null, create(ngramDictionary, tagDictionary), listeners);
   }
   
   /**
@@ -173,6 +194,20 @@ public class POSTaggerCrossValidator {
         this.factory = POSTaggerFactory.create(this.factoryClassName,
             ngramDict, tagDictionary);
       }
+      if (this.tagdicCutoff != null) {
+        POSDictionary dict = this.factory.getPOSDictionary();
+        if (dict == null) {
+          dict = this.factory.createEmptyPOSDictionary();
+        }
+        if (dict instanceof MutableTagDictionary) {
+          POSTaggerME.populatePOSDictionary(trainingSampleStream, dict,
+              this.tagdicCutoff);
+        } else {
+          throw new IllegalArgumentException(
+              "Can't extend a POSDictionary that does not implement MutableTagDictionary.");
+        }
+        trainingSampleStream.reset();
+      }
       
       POSModel model = POSTaggerME.train(languageCode, trainingSampleStream,
           params, this.factory);
@@ -182,6 +217,11 @@ public class POSTaggerCrossValidator {
       evaluator.evaluate(trainingSampleStream.getTestSampleStream());
 
       wordAccuracy.add(evaluator.getWordAccuracy(), evaluator.getWordCount());
+
+      if (this.tagdicCutoff != null) {
+        this.factory.rereadPOSDictionary();
+      }
+
     }
   }
   

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java Tue May  8 20:43:34 2012
@@ -229,4 +229,13 @@ public class POSTaggerFactory extends Ba
     }
     return theFactory;
   }
+
+  public void rereadPOSDictionary() throws InvalidFormatException, IOException {
+    this.posDictionary = null;
+  }
+
+  public POSDictionary createEmptyPOSDictionary() {
+    this.posDictionary = new POSDictionary();
+    return this.posDictionary;
+  }
 }