You are viewing a plain-text version of this content. (The link to the canonical version, labelled "here" on the original page, was lost in the plain-text conversion.)
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/01/30 14:38:56 UTC

svn commit: r1562819 - /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java

Author: joern
Date: Thu Jan 30 13:38:56 2014
New Revision: 1562819

URL: http://svn.apache.org/r1562819
Log:
OPENNLP-641 Extended TrainerFactory to support true sequence training

Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1562819&r1=1562818&r2=1562819&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Thu Jan 30 13:38:56 2014
@@ -54,6 +54,9 @@ public class TrainerFactory {
         else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) {
           return EventModelSequenceTrainer.SEQUENCE_VALUE;
         }
+        else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
+          return SequenceTrainer.SEQUENCE_VALUE;
+        }
       }
     } catch (ClassNotFoundException e) {
     }
@@ -61,6 +64,9 @@ public class TrainerFactory {
     return "UNKOWN";
   }
   
+  // Note: A better way to indicate which training approach is necessary would be
+  // to use an enum which encodes the different possibilities ...
+  
   public static boolean isSupportEvent(Map<String, String> trainParams) {
     
     String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -72,15 +78,19 @@ public class TrainerFactory {
       }
     }
     
-    if (trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM) != null) {
-      return EventTrainer.EVENT_VALUE.equals(trainParams
-          .get(AbstractTrainer.TRAINER_TYPE_PARAM));
+    if (trainerType != null) {
+      return EventTrainer.EVENT_VALUE.equals(trainerType);
     } 
     
     return true;
   }
-
+  
+  @Deprecated
   public static boolean isSupportSequence(Map<String, String> trainParams) {
+    return isSupportEventModelSequenceTraining(trainParams);
+  }
+  
+  public static boolean isSupportEventModelSequenceTraining(Map<String, String> trainParams) {
     
     String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
     
@@ -91,16 +101,29 @@ public class TrainerFactory {
       }
     }
     
-    if (EventModelSequenceTrainer.SEQUENCE_VALUE.equals(trainerType)) {
-      return true;
+    return EventModelSequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
+  }
+  
+  public static boolean isSupportSequenceTraining(Map<String, String> trainParams) {
+    String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
+    
+    if (trainerType == null) {
+      String alogrithmValue = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+      if (alogrithmValue != null) {
+        trainerType = getPluggableTrainerType(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
+      }
     }
     
-    return false;
+    return SequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
   }
-
+  
+  // TODO: How to do the testing ?!
+  // is support event sequence ?
+  // is support sequence ?
+  
   /**
    * This method is deprecated and should not be used! <br>
-   * Use {@link TrainerFactory#isSupportSequence(Map)} instead.
+   * Use {@link TrainerFactory#isSupportEventModelSequenceTraining(Map)} instead.
    * 
    * @param trainParams
    * @return
@@ -111,6 +134,20 @@ public class TrainerFactory {
         .equals(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
   }
   
+  
+  public static SequenceTrainer getSequenceModelTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    String trainerType = getTrainerType(trainParams);
+    if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+      return TrainerFactory.<SequenceTrainer> create(
+          BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+    } else {
+      return TrainerFactory.<SequenceTrainer> create(trainerType, trainParams,
+          reportMap);
+    }
+    
+  }
+  
   public static EventModelSequenceTrainer getSequenceTrainer(
       Map<String, String> trainParams, Map<String, String> reportMap) {
     String trainerType = getTrainerType(trainParams);
@@ -184,7 +221,7 @@ public class TrainerFactory {
       Class<?> trainerClass = Class.forName(className);
       if(trainerClass != null &&
           (EventTrainer.class.isAssignableFrom(trainerClass)
-              || EventModelSequenceTrainer.class.isAssignableFrom(trainerClass))) {
+              || EventModelSequenceTrainer.class.isAssignableFrom(trainerClass) || SequenceTrainer.class.isAssignableFrom(trainerClass))) {
         return true;
       }
     } catch (ClassNotFoundException e) {