You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/01/17 20:23:10 UTC

svn commit: r1559228 - /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java

Author: joern
Date: Fri Jan 17 19:23:10 2014
New Revision: 1559228

URL: http://svn.apache.org/r1559228
Log:
OPENNLP-635 pluggable trainers can now be correctly classified as event or sequence

Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1559228&r1=1559227&r2=1559228&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Fri Jan 17 19:23:10 2014
@@ -43,23 +43,71 @@ public class TrainerFactory {
     BUILTIN_TRAINERS = Collections.unmodifiableMap(_trainers);
   }
 
-  public static boolean isSupportEvent(Map<String, String> trainParams) {
-    if (trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM) != null) {
-      if(EventTrainer.EVENT_VALUE.equals(trainParams
-            .get(AbstractTrainer.TRAINER_TYPE_PARAM))) {
-        return true;
+  /**
+   * Classifies a pluggable trainer class as an event or a sequence trainer.
+   *
+   * @param className fully qualified name of the trainer class, may be null
+   * @return {@link EventTrainer#EVENT_VALUE} or
+   *         {@link SequenceTrainer#SEQUENCE_VALUE}, or null if the class
+   *         cannot be loaded or implements neither interface
+   */
+  private static String getPluggableTrainerType(String className) {
+    if (className == null) {
+      // Class.forName(null) would throw NullPointerException
+      return null;
+    }
+    try {
+      Class<?> trainerClass = Class.forName(className);
+      if (EventTrainer.class.isAssignableFrom(trainerClass)) {
+        return EventTrainer.EVENT_VALUE;
       }
-      return false;
-    } else {
-      return true; // default to event train
+      if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
+        return SequenceTrainer.SEQUENCE_VALUE;
+      }
+    } catch (ClassNotFoundException ignored) {
+      // Not a loadable class name (presumably a built-in algorithm name);
+      // callers apply their own default.
+    }
+    return null;
+  }
+
+  /**
+   * Checks whether the configured trainer is an event trainer.
+   * <p>
+   * An explicit trainer type parameter takes precedence. Otherwise the
+   * algorithm parameter is tried as a pluggable trainer class name. If no
+   * type can be determined this way, event training is assumed.
+   *
+   * @param trainParams the training parameters
+   * @return true if the trainer is (or defaults to) an event trainer
+   */
+  public static boolean isSupportEvent(Map<String, String> trainParams) {
+    String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
+    if (trainerType == null) {
+      trainerType = getPluggableTrainerType(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
     }
+    if (trainerType != null) {
+      // Use the resolved type so a pluggable SequenceTrainer is rejected.
+      return EventTrainer.EVENT_VALUE.equals(trainerType);
+    }
+    // default to event training
+    return true;
   }
 
+  /**
+   * Checks whether the configured trainer is a sequence trainer.
+   * <p>
+   * An explicit trainer type parameter takes precedence. Otherwise the
+   * algorithm parameter is tried as a pluggable trainer class name.
+   *
+   * @param trainParams the training parameters
+   * @return true only if the trainer is known to be a sequence trainer
+   */
   public static boolean isSupportSequence(Map<String, String> trainParams) {
-    if (SequenceTrainer.SEQUENCE_VALUE.equals(trainParams
-        .get(AbstractTrainer.TRAINER_TYPE_PARAM))) {
-      return true;
+    String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
+    if (trainerType == null) {
+      trainerType = getPluggableTrainerType(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
     }
-    return false;
+    return SequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
   }