You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/01/30 23:33:31 UTC
svn commit: r1563002 - in
/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools:
ml/TrainerFactory.java namefind/NameFinderME.java
Author: joern
Date: Thu Jan 30 22:33:30 2014
New Revision: 1563002
URL: http://svn.apache.org/r1563002
Log:
OPENNLP-641 Added new method to detect the trainer type to Trainer Factory and updated Name Finder ME to use it
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1563002&r1=1563001&r2=1563002&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Thu Jan 30 22:33:30 2014
@@ -27,8 +27,21 @@ import opennlp.tools.ml.maxent.quasinewt
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
+// TODO: Another issue is that certain trainers will have certain properties.
+// The code using the trainer should have the possibility to get these properties;
+// in our case this could be communicated via the trainer interface itself,
+// for example via property methods.
+
+//
+
public class TrainerFactory {
+ public enum TrainerType {
+ EVENT_MODEL_TRAINER, // returned by getTrainerType for EventTrainer implementations, and when no algorithm is configured
+ EVENT_MODEL_SEQUENCE_TRAINER, // returned by getTrainerType for EventModelSequenceTrainer implementations
+ SEQUENCE_TRAINER // returned by getTrainerType for SequenceTrainer implementations
+ }
+
// built-in trainers
private static final Map<String, Class> BUILTIN_TRAINERS;
@@ -43,6 +56,7 @@ public class TrainerFactory {
BUILTIN_TRAINERS = Collections.unmodifiableMap(_trainers);
}
+ @Deprecated
private static String getPluggableTrainerType(String className) {
try {
Class<?> trainerClass = Class.forName(className);
@@ -64,9 +78,51 @@ public class TrainerFactory {
return null;
}
- // Note: A better way to indicate which training approach is necessary would be
- // to use an enum which encodes the different possibilities ...
+ /**
+ * Determines the trainer type based on the ALGORITHM_PARAM value.
+ * <p>
+ * A missing value yields EVENT_MODEL_TRAINER. Otherwise the value is first looked up in
+ * BUILTIN_TRAINERS and, if it is not a key there, loaded as a class name via Class.forName.
+ * The resulting class is then matched against EventTrainer, EventModelSequenceTrainer and
+ * SequenceTrainer, in that order.
+ *
+ * @param trainParams the training parameters; must not be null, since it is dereferenced without a check
+ * @return the trainer type or null if type couldn't be determined.
+ */
+ public static TrainerType getTrainerType(Map<String, String> trainParams){
+
+ String alogrithmValue = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+
+ // Check if it is defaulting to the MAXENT trainer
+ if (alogrithmValue == null) {
+ return TrainerType.EVENT_MODEL_TRAINER;
+ }
+
+ Class<?> trainerClass = BUILTIN_TRAINERS.get(alogrithmValue);
+
+ // TODO: This will not work in an OSGi environment!
+ if (trainerClass == null) {
+ try {
+ trainerClass = Class.forName(alogrithmValue); // not a built-in key: treat the value as a class name
+ } catch (ClassNotFoundException e) { // deliberately ignored: trainerClass stays null and null is returned below
+ }
+ }
+
+ if(trainerClass != null) {
+
+ if (EventTrainer.class.isAssignableFrom(trainerClass)) { // order matters: the first matching interface wins
+ return TrainerType.EVENT_MODEL_TRAINER;
+ }
+ else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) {
+ return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER;
+ }
+ else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
+ return TrainerType.SEQUENCE_TRAINER;
+ }
+ }
+
+ return null; // class could not be resolved, or it implements none of the three trainer interfaces
+ }
+ /**
+ * @deprecated use getTrainerType instead!
+ */
+ @Deprecated
public static boolean isSupportEvent(Map<String, String> trainParams) {
String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -85,11 +141,18 @@ public class TrainerFactory {
return true;
}
+ /**
+ * @deprecated use getTrainerType instead!
+ */
@Deprecated
public static boolean isSupportSequence(Map<String, String> trainParams) {
return isSupportEventModelSequenceTraining(trainParams);
}
+ /**
+ * @deprecated use getTrainerType instead!
+ */
+ @Deprecated
public static boolean isSupportEventModelSequenceTraining(Map<String, String> trainParams) {
String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -104,6 +167,10 @@ public class TrainerFactory {
return EventModelSequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
}
+ /**
+ * @deprecated use getTrainerType instead!
+ */
+ @Deprecated
public static boolean isSupportSequenceTraining(Map<String, String> trainParams) {
String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -122,11 +189,7 @@ public class TrainerFactory {
// is support sequence ?
/**
- * This method is deprecated and should not be used! <br>
- * Use {@link TrainerFactory#isSupportEventModelSequenceTraining(Map)} instead.
- *
- * @param trainParams
- * @return
+ * @deprecated use getTrainerType instead!
*/
@Deprecated
public static boolean isSequenceTraining(Map<String, String> trainParams) {
@@ -137,7 +200,7 @@ public class TrainerFactory {
public static SequenceTrainer getSequenceModelTrainer(Map<String, String> trainParams,
Map<String, String> reportMap) {
- String trainerType = getTrainerType(trainParams);
+ String trainerType = getTrainerTypeInt(trainParams);
if (BUILTIN_TRAINERS.containsKey(trainerType)) {
return TrainerFactory.<SequenceTrainer> create(
BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
@@ -148,9 +211,15 @@ public class TrainerFactory {
}
+ public static EventModelSequenceTrainer getEventModelSequenceTrainer(Map<String, String> trainParams,
+ Map<String, String> reportMap) {
+ return getSequenceTrainer(trainParams, reportMap); // pure delegation to the deprecated getSequenceTrainer, exposed under the new name
+ }
+
+ @Deprecated
public static EventModelSequenceTrainer getSequenceTrainer(
Map<String, String> trainParams, Map<String, String> reportMap) {
- String trainerType = getTrainerType(trainParams);
+ String trainerType = getTrainerTypeInt(trainParams);
if (BUILTIN_TRAINERS.containsKey(trainerType)) {
return TrainerFactory.<EventModelSequenceTrainer> create(
BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
@@ -162,7 +231,7 @@ public class TrainerFactory {
public static EventTrainer getEventTrainer(Map<String, String> trainParams,
Map<String, String> reportMap) {
- String trainerType = getTrainerType(trainParams);
+ String trainerType = getTrainerTypeInt(trainParams);
if(trainerType == null) {
// default to MAXENT
return new GIS(trainParams, reportMap);
@@ -230,7 +299,7 @@ public class TrainerFactory {
return false;
}
- private static String getTrainerType(Map<String, String> trainParams) {
+ private static String getTrainerTypeInt(Map<String, String> trainParams) {
return trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
}
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java?rev=1563002&r1=1563001&r2=1563002&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java Thu Jan 30 22:33:30 2014
@@ -29,8 +29,11 @@ import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import opennlp.tools.ml.EventModelSequenceTrainer;
+import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
+import opennlp.tools.ml.TrainerFactory.TrainerType;
import opennlp.tools.ml.model.EventStream;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
@@ -339,26 +342,31 @@ public class NameFinderME implements Tok
SequenceClassificationModel<String> seqModel = null;
- if (TrainerFactory.isSupportEvent((trainParams.getSettings()))) {
+ TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());
+
+ if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
EventStream eventStream = new NameFinderEventStream(samples, type,
new DefaultNameContextGenerator(featureGenerator));
- nameFinderModel = TrainUtil.train(eventStream, trainParams.getSettings(), manifestInfoEntries);
+ EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams.getSettings(), manifestInfoEntries);
+ nameFinderModel = trainer.train(eventStream);
}
- else if (TrainerFactory.isSupportEventModelSequenceTraining((trainParams.getSettings()))) {
+ else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, featureGenerator);
- nameFinderModel = TrainUtil.train(ss, trainParams.getSettings(), manifestInfoEntries);
+ EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(
+ trainParams.getSettings(), manifestInfoEntries);
+ nameFinderModel = trainer.train(ss);
}
- else if (TrainerFactory.isSupportSequenceTraining((trainParams.getSettings()))) {
+ else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
trainParams.getSettings(), manifestInfoEntries);
-
+
NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, featureGenerator, false);
seqModel = trainer.train(ss);
}
else {
- throw new IllegalStateException("Unexpected trainer type required!");
+ throw new IllegalStateException("Unexpected trainer type!");
}
// depending on which one is not null!
@@ -366,9 +374,10 @@ public class NameFinderME implements Tok
return new TokenNameFinderModel(languageCode, seqModel, null,
resources, manifestInfoEntries);
}
-
- return new TokenNameFinderModel(languageCode, nameFinderModel,
- resources, manifestInfoEntries);
+ else {
+ return new TokenNameFinderModel(languageCode, nameFinderModel,
+ resources, manifestInfoEntries);
+ }
}
/**