You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/01/30 14:38:56 UTC
svn commit: r1562819 -
/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
Author: joern
Date: Thu Jan 30 13:38:56 2014
New Revision: 1562819
URL: http://svn.apache.org/r1562819
Log:
OPENNLP-641 Extended TrainerFactory to support true sequence training
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1562819&r1=1562818&r2=1562819&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Thu Jan 30 13:38:56 2014
@@ -54,6 +54,9 @@ public class TrainerFactory {
else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) {
return EventModelSequenceTrainer.SEQUENCE_VALUE;
}
+ else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
+ return SequenceTrainer.SEQUENCE_VALUE;
+ }
}
} catch (ClassNotFoundException e) {
}
@@ -61,6 +64,9 @@ public class TrainerFactory {
return "UNKOWN";
}
+ // Note: A better way to indicate which training approach is necessary would be
+ // to use an enum which encodes the different possibilities ...
+
public static boolean isSupportEvent(Map<String, String> trainParams) {
String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -72,15 +78,19 @@ public class TrainerFactory {
}
}
- if (trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM) != null) {
- return EventTrainer.EVENT_VALUE.equals(trainParams
- .get(AbstractTrainer.TRAINER_TYPE_PARAM));
+ if (trainerType != null) {
+ return EventTrainer.EVENT_VALUE.equals(trainerType);
}
return true;
}
-
+
+ @Deprecated
public static boolean isSupportSequence(Map<String, String> trainParams) {
+ return isSupportEventModelSequenceTraining(trainParams);
+ }
+
+ public static boolean isSupportEventModelSequenceTraining(Map<String, String> trainParams) {
String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -91,16 +101,29 @@ public class TrainerFactory {
}
}
- if (EventModelSequenceTrainer.SEQUENCE_VALUE.equals(trainerType)) {
- return true;
+ return EventModelSequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
+ }
+
+ public static boolean isSupportSequenceTraining(Map<String, String> trainParams) {
+ String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
+
+ if (trainerType == null) {
+ String alogrithmValue = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+ if (alogrithmValue != null) {
+ trainerType = getPluggableTrainerType(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
+ }
}
- return false;
+ return SequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
}
-
+
+ // TODO: How should the following checks be tested?
+ // - is event model sequence training supported?
+ // - is sequence training supported?
+
/**
* This method is deprecated and should not be used! <br>
- * Use {@link TrainerFactory#isSupportSequence(Map)} instead.
+ * Use {@link TrainerFactory#isSupportEventModelSequenceTraining(Map)} instead.
*
* @param trainParams
* @return
@@ -111,6 +134,20 @@ public class TrainerFactory {
.equals(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
}
+
+ public static SequenceTrainer getSequenceModelTrainer(Map<String, String> trainParams,
+ Map<String, String> reportMap) {
+ String trainerType = getTrainerType(trainParams);
+ if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+ return TrainerFactory.<SequenceTrainer> create(
+ BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+ } else {
+ return TrainerFactory.<SequenceTrainer> create(trainerType, trainParams,
+ reportMap);
+ }
+
+ }
+
public static EventModelSequenceTrainer getSequenceTrainer(
Map<String, String> trainParams, Map<String, String> reportMap) {
String trainerType = getTrainerType(trainParams);
@@ -184,7 +221,7 @@ public class TrainerFactory {
Class<?> trainerClass = Class.forName(className);
if(trainerClass != null &&
(EventTrainer.class.isAssignableFrom(trainerClass)
- || EventModelSequenceTrainer.class.isAssignableFrom(trainerClass))) {
+ || EventModelSequenceTrainer.class.isAssignableFrom(trainerClass) || SequenceTrainer.class.isAssignableFrom(trainerClass))) {
return true;
}
} catch (ClassNotFoundException e) {