You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by co...@apache.org on 2012/05/08 22:43:35 UTC
svn commit: r1335756 - in
/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools:
cmdline/postag/POSTaggerCrossValidatorTool.java
cmdline/postag/POSTaggerTrainerTool.java
postag/POSTaggerCrossValidator.java postag/POSTaggerFactory.java
Author: colen
Date: Tue May 8 20:43:34 2012
New Revision: 1335756
URL: http://svn.apache.org/viewvc?rev=1335756&view=rev
Log:
OPENNLP-508 Included the mutable dictionary capability to the POSTagger trainer and cross validator tools
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java Tue May 8 20:43:34 2012
@@ -92,8 +92,8 @@ public final class POSTaggerCrossValidat
}
validator = new POSTaggerCrossValidator(factory.getLang(), mlParams,
- tagdict, params.getNgram(), params.getFactory(),
- missclassifiedListener, reportListener);
+ tagdict, params.getNgram(), params.getTagDictCutoff(),
+ params.getFactory(), missclassifiedListener, reportListener);
validator.evaluate(sampleStream, params.getFolds());
} catch (IOException e) {
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java Tue May 8 20:43:34 2012
@@ -28,6 +28,7 @@ import opennlp.tools.cmdline.TerminateTo
import opennlp.tools.cmdline.params.TrainingToolParams;
import opennlp.tools.cmdline.postag.POSTaggerTrainerTool.TrainerToolParams;
import opennlp.tools.dictionary.Dictionary;
+import opennlp.tools.postag.MutableTagDictionary;
import opennlp.tools.postag.POSDictionary;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
@@ -103,6 +104,27 @@ public final class POSTaggerTrainerTool
throw new TerminateToolException(-1, e.getMessage());
}
+ if (params.getTagDictCutoff() != null) {
+ try {
+ POSDictionary dict = postaggerFactory.getPOSDictionary();
+ if (dict == null) {
+ dict = postaggerFactory.createEmptyPOSDictionary();
+ }
+ if (dict instanceof MutableTagDictionary) {
+ POSTaggerME.populatePOSDictionary(sampleStream, dict,
+ params.getTagDictCutoff());
+ } else {
+ throw new IllegalArgumentException(
+ "Can't extend a POSDictionary that does not implement MutableTagDictionary.");
+ }
+ sampleStream.reset();
+ } catch (IOException e) {
+ throw new TerminateToolException(-1,
+ "IO error while creating/extending POS Dictionary: "
+ + e.getMessage());
+ }
+ }
+
POSModel model;
try {
model = opennlp.tools.postag.POSTaggerME.train(factory.getLang(),
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java Tue May 8 20:43:34 2012
@@ -44,6 +44,8 @@ public class POSTaggerCrossValidator {
private String factoryClassName;
/* user can also send a ready to use factory */
private POSTaggerFactory factory;
+
+ private Integer tagdicCutoff = null;
/**
* Creates a {@link POSTaggerCrossValidator} that builds a ngram dictionary
@@ -52,7 +54,7 @@ public class POSTaggerCrossValidator {
*/
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, POSDictionary tagDictionary,
- Integer ngramCutoff, String factoryClass,
+ Integer ngramCutoff, Integer tagdicCutoff, String factoryClass,
POSTaggerEvaluationMonitor... listeners) {
this.languageCode = languageCode;
this.params = trainParam;
@@ -61,6 +63,7 @@ public class POSTaggerCrossValidator {
this.listeners = listeners;
this.factoryClassName = factoryClass;
this.ngramDictionary = null;
+ this.tagdicCutoff = tagdicCutoff;
}
/**
@@ -69,7 +72,7 @@ public class POSTaggerCrossValidator {
*/
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, POSTaggerFactory factory,
- POSTaggerEvaluationMonitor... listeners) {
+ Integer tagdicCutoff, POSTaggerEvaluationMonitor... listeners) {
this.languageCode = languageCode;
this.params = trainParam;
this.listeners = listeners;
@@ -77,9 +80,27 @@ public class POSTaggerCrossValidator {
this.tagDictionary = null;
this.ngramDictionary = null;
this.ngramCutoff = null;
+ this.tagdicCutoff = tagdicCutoff;
}
/**
+ * Creates a {@link POSTaggerCrossValidator} using the given
+ * {@link POSTaggerFactory}.
+ */
+ public POSTaggerCrossValidator(String languageCode,
+ TrainingParameters trainParam, Integer posdicCutoff, POSTaggerFactory factory,
+ POSTaggerEvaluationMonitor... listeners) {
+ this.languageCode = languageCode;
+ this.params = trainParam;
+ this.listeners = listeners;
+ this.factory = factory;
+ this.tagDictionary = null;
+ this.ngramDictionary = null;
+ this.ngramCutoff = null;
+ this.tagdicCutoff = posdicCutoff;
+ }
+
+ /**
* @deprecated use
* {@link #POSTaggerCrossValidator(String, TrainingParameters, POSTaggerFactory, POSTaggerEvaluationMonitor...)}
* instead and pass in a {@link TrainingParameters} object and a
@@ -87,7 +108,7 @@ public class POSTaggerCrossValidator {
*/
public POSTaggerCrossValidator(String languageCode, ModelType modelType, POSDictionary tagDictionary,
Dictionary ngramDictionary, int cutoff, int iterations) {
- this(languageCode, create(modelType, cutoff, iterations), create(ngramDictionary, tagDictionary));
+ this(languageCode, create(modelType, cutoff, iterations), null, create(ngramDictionary, tagDictionary));
}
/**
@@ -98,7 +119,7 @@ public class POSTaggerCrossValidator {
*/
public POSTaggerCrossValidator(String languageCode, ModelType modelType, POSDictionary tagDictionary,
Dictionary ngramDictionary) {
- this(languageCode, create(modelType, 5, 100), create(ngramDictionary, tagDictionary));
+ this(languageCode, create(modelType, 5, 100), null, create(ngramDictionary, tagDictionary));
}
/**
@@ -109,7 +130,7 @@ public class POSTaggerCrossValidator {
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, POSDictionary tagDictionary,
POSTaggerEvaluationMonitor... listeners) {
- this(languageCode, trainParam, create(null, tagDictionary), listeners);
+ this(languageCode, trainParam, null, create(null, tagDictionary), listeners);
}
/**
@@ -121,7 +142,7 @@ public class POSTaggerCrossValidator {
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, POSDictionary tagDictionary,
Integer ngramCutoff, POSTaggerEvaluationMonitor... listeners) {
- this(languageCode, trainParam, tagDictionary, ngramCutoff,
+ this(languageCode, trainParam, tagDictionary, ngramCutoff, null,
POSTaggerFactory.class.getCanonicalName(), listeners);
}
@@ -133,7 +154,7 @@ public class POSTaggerCrossValidator {
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, POSDictionary tagDictionary,
Dictionary ngramDictionary, POSTaggerEvaluationMonitor... listeners) {
- this(languageCode, trainParam, create(ngramDictionary, tagDictionary), listeners);
+ this(languageCode, trainParam, null, create(ngramDictionary, tagDictionary), listeners);
}
/**
@@ -173,6 +194,20 @@ public class POSTaggerCrossValidator {
this.factory = POSTaggerFactory.create(this.factoryClassName,
ngramDict, tagDictionary);
}
+ if (this.tagdicCutoff != null) {
+ POSDictionary dict = this.factory.getPOSDictionary();
+ if (dict == null) {
+ dict = this.factory.createEmptyPOSDictionary();
+ }
+ if (dict instanceof MutableTagDictionary) {
+ POSTaggerME.populatePOSDictionary(trainingSampleStream, dict,
+ this.tagdicCutoff);
+ } else {
+ throw new IllegalArgumentException(
+ "Can't extend a POSDictionary that does not implement MutableTagDictionary.");
+ }
+ trainingSampleStream.reset();
+ }
POSModel model = POSTaggerME.train(languageCode, trainingSampleStream,
params, this.factory);
@@ -182,6 +217,11 @@ public class POSTaggerCrossValidator {
evaluator.evaluate(trainingSampleStream.getTestSampleStream());
wordAccuracy.add(evaluator.getWordAccuracy(), evaluator.getWordCount());
+
+ if (this.tagdicCutoff != null) {
+ this.factory.rereadPOSDictionary();
+ }
+
}
}
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java?rev=1335756&r1=1335755&r2=1335756&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java Tue May 8 20:43:34 2012
@@ -229,4 +229,13 @@ public class POSTaggerFactory extends Ba
}
return theFactory;
}
+
+ public void rereadPOSDictionary() throws InvalidFormatException, IOException {
+ this.posDictionary = null;
+ }
+
+ public POSDictionary createEmptyPOSDictionary() {
+ this.posDictionary = new POSDictionary();
+ return this.posDictionary;
+ }
}