You are viewing a plain-text version of this content. The canonical link for it is here (the link itself was not preserved in this plain-text version).
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/01/30 23:33:31 UTC

svn commit: r1563002 - in /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools: ml/TrainerFactory.java namefind/NameFinderME.java

Author: joern
Date: Thu Jan 30 22:33:30 2014
New Revision: 1563002

URL: http://svn.apache.org/r1563002
Log:
OPENNLP-641: Added a new method to TrainerFactory that detects the trainer type, and updated NameFinderME to use it.

Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1563002&r1=1563001&r2=1563002&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Thu Jan 30 22:33:30 2014
@@ -27,8 +27,21 @@ import opennlp.tools.ml.maxent.quasinewt
 import opennlp.tools.ml.perceptron.PerceptronTrainer;
 import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
 
+// TODO: Another issue is that certain trainers will have certain properties;
+// the code using the trainer should have the possibility to get these properties.
+// In our case this could be communicated via the trainer interface itself,
+// for example via property methods.
+
+// 
+
 public class TrainerFactory {
 
+  public enum TrainerType {
+    EVENT_MODEL_TRAINER,
+    EVENT_MODEL_SEQUENCE_TRAINER,
+    SEQUENCE_TRAINER
+  }
+  
   // built-in trainers
   private static final Map<String, Class> BUILTIN_TRAINERS;
 
@@ -43,6 +56,7 @@ public class TrainerFactory {
     BUILTIN_TRAINERS = Collections.unmodifiableMap(_trainers);
   }
 
+  @Deprecated
   private static String getPluggableTrainerType(String className) {
     try {
       Class<?> trainerClass = Class.forName(className);
@@ -64,9 +78,51 @@ public class TrainerFactory {
     return null;
   }
   
-  // Note: A better way to indicate which training approach is necessary would be
-  // to use an enum which encodes the different possibilities ...
+  /**
+   * Determines the trainer type based on the ALGORITHM_PARAM value.
+   * 
+   * @param trainParams
+   * @return the trainer type or null if type couldn't be determined.
+   */
+  public static TrainerType getTrainerType(Map<String, String> trainParams){
+
+    String alogrithmValue = trainParams.get(AbstractTrainer.ALGORITHM_PARAM); 
+    
+    // Check if it is defaulting to the MAXENT trainer
+    if (alogrithmValue == null) {
+      return TrainerType.EVENT_MODEL_TRAINER;
+    }
+    
+    Class<?> trainerClass = BUILTIN_TRAINERS.get(alogrithmValue);
+    
+    // TODO: This will not work in an OSGi environment!
+    if (trainerClass == null) {
+      try {
+        trainerClass = Class.forName(alogrithmValue);
+      } catch (ClassNotFoundException e) {
+      }
+    }
+    
+    if(trainerClass != null) {
+      
+      if (EventTrainer.class.isAssignableFrom(trainerClass)) {
+        return TrainerType.EVENT_MODEL_TRAINER;
+      }
+      else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) {
+        return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER;
+      }
+      else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
+        return TrainerType.SEQUENCE_TRAINER;
+      }
+    }
+    
+    return null;
+  }
   
+  /**
+   * @deprecated use getTrainerType instead!
+   */
+  @Deprecated
   public static boolean isSupportEvent(Map<String, String> trainParams) {
     
     String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -85,11 +141,18 @@ public class TrainerFactory {
     return true;
   }
   
+  /**
+   * @deprecated use getTrainerType instead!
+   */
   @Deprecated
   public static boolean isSupportSequence(Map<String, String> trainParams) {
     return isSupportEventModelSequenceTraining(trainParams);
   }
   
+  /**
+   * @deprecated use getTrainerType instead!
+   */
+  @Deprecated
   public static boolean isSupportEventModelSequenceTraining(Map<String, String> trainParams) {
     
     String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
@@ -104,6 +167,10 @@ public class TrainerFactory {
     return EventModelSequenceTrainer.SEQUENCE_VALUE.equals(trainerType);
   }
   
+  /**
+   * @deprecated use getTrainerType instead!
+   */
+  @Deprecated
   public static boolean isSupportSequenceTraining(Map<String, String> trainParams) {
     String trainerType = trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM);
     
@@ -122,11 +189,7 @@ public class TrainerFactory {
   // is support sequence ?
   
   /**
-   * This method is deprecated and should not be used! <br>
-   * Use {@link TrainerFactory#isSupportEventModelSequenceTraining(Map)} instead.
-   * 
-   * @param trainParams
-   * @return
+   * @deprecated use getTrainerType instead!
    */
   @Deprecated
   public static boolean isSequenceTraining(Map<String, String> trainParams) {
@@ -137,7 +200,7 @@ public class TrainerFactory {
   
   public static SequenceTrainer getSequenceModelTrainer(Map<String, String> trainParams,
       Map<String, String> reportMap) {
-    String trainerType = getTrainerType(trainParams);
+    String trainerType = getTrainerTypeInt(trainParams);
     if (BUILTIN_TRAINERS.containsKey(trainerType)) {
       return TrainerFactory.<SequenceTrainer> create(
           BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
@@ -148,9 +211,15 @@ public class TrainerFactory {
     
   }
   
+  public static EventModelSequenceTrainer getEventModelSequenceTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    return getSequenceTrainer(trainParams, reportMap);
+  }
+  
+  @Deprecated
   public static EventModelSequenceTrainer getSequenceTrainer(
       Map<String, String> trainParams, Map<String, String> reportMap) {
-    String trainerType = getTrainerType(trainParams);
+    String trainerType = getTrainerTypeInt(trainParams);
     if (BUILTIN_TRAINERS.containsKey(trainerType)) {
       return TrainerFactory.<EventModelSequenceTrainer> create(
           BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
@@ -162,7 +231,7 @@ public class TrainerFactory {
 
   public static EventTrainer getEventTrainer(Map<String, String> trainParams,
       Map<String, String> reportMap) {
-    String trainerType = getTrainerType(trainParams);
+    String trainerType = getTrainerTypeInt(trainParams);
     if(trainerType == null) {
       // default to MAXENT
       return new GIS(trainParams, reportMap);
@@ -230,7 +299,7 @@ public class TrainerFactory {
     return false;
   }
 
-  private static String getTrainerType(Map<String, String> trainParams) {
+  private static String getTrainerTypeInt(Map<String, String> trainParams) {
     return trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
   }
 

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java?rev=1563002&r1=1563001&r2=1563002&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/namefind/NameFinderME.java Thu Jan 30 22:33:30 2014
@@ -29,8 +29,11 @@ import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
+import opennlp.tools.ml.EventModelSequenceTrainer;
+import opennlp.tools.ml.EventTrainer;
 import opennlp.tools.ml.SequenceTrainer;
 import opennlp.tools.ml.TrainerFactory;
+import opennlp.tools.ml.TrainerFactory.TrainerType;
 import opennlp.tools.ml.model.EventStream;
 import opennlp.tools.ml.model.MaxentModel;
 import opennlp.tools.ml.model.SequenceClassificationModel;
@@ -339,26 +342,31 @@ public class NameFinderME implements Tok
      
      SequenceClassificationModel<String> seqModel = null;
      
-     if (TrainerFactory.isSupportEvent((trainParams.getSettings()))) {
+     TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());
+     
+     if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
        EventStream eventStream = new NameFinderEventStream(samples, type,
            new DefaultNameContextGenerator(featureGenerator));
 
-       nameFinderModel = TrainUtil.train(eventStream, trainParams.getSettings(), manifestInfoEntries);
+       EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams.getSettings(), manifestInfoEntries);
+       nameFinderModel = trainer.train(eventStream);
      }
-     else if (TrainerFactory.isSupportEventModelSequenceTraining((trainParams.getSettings()))) {
+     else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
        NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, featureGenerator);
 
-       nameFinderModel = TrainUtil.train(ss, trainParams.getSettings(), manifestInfoEntries);
+       EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(
+           trainParams.getSettings(), manifestInfoEntries);
+       nameFinderModel = trainer.train(ss);
      }
-     else if (TrainerFactory.isSupportSequenceTraining((trainParams.getSettings()))) {
+     else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
        SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
            trainParams.getSettings(), manifestInfoEntries);
-
+       
        NameSampleSequenceStream ss = new NameSampleSequenceStream(samples, featureGenerator, false);
        seqModel = trainer.train(ss);
      }
      else {
-       throw new IllegalStateException("Unexpected trainer type required!");
+       throw new IllegalStateException("Unexpected trainer type!");
      }
      
      // depending on which one is not null!
@@ -366,9 +374,10 @@ public class NameFinderME implements Tok
        return new TokenNameFinderModel(languageCode, seqModel, null,
            resources, manifestInfoEntries);
      }
-     
-     return new TokenNameFinderModel(languageCode, nameFinderModel,
-         resources, manifestInfoEntries);
+     else {
+       return new TokenNameFinderModel(languageCode, nameFinderModel,
+           resources, manifestInfoEntries);
+     }
    }
 
   /**