You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/02/18 16:37:26 UTC

svn commit: r1569390 - in /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml: TrainerFactory.java model/TrainUtil.java

Author: joern
Date: Tue Feb 18 15:37:25 2014
New Revision: 1569390

URL: http://svn.apache.org/r1569390
Log:
OPENNLP-636: Removed usage of Class.forName from all non-deprecated methods. The OpenNLP code base should be converted to use the new methods.

Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1569390&r1=1569389&r2=1569390&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Tue Feb 18 15:37:25 2014
@@ -26,13 +26,8 @@ import opennlp.tools.ml.maxent.GIS;
 import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
 import opennlp.tools.ml.perceptron.PerceptronTrainer;
 import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
-
-// TODO: Another issue is that certain trainers will have certain properties,
-// the code using the trainer should have the possibilites to get these properties
-// in our case this could be communicated via the trainer interface itself!
-// For example via property methods.
-
-// 
+import opennlp.tools.util.ext.ExtensionLoader;
+import opennlp.tools.util.ext.ExtensionNotLoadedException;
 
 public class TrainerFactory {
 
@@ -95,14 +90,6 @@ public class TrainerFactory {
     
     Class<?> trainerClass = BUILTIN_TRAINERS.get(alogrithmValue);
     
-    // TODO: This will not work in an OSGi environment!
-    if (trainerClass == null) {
-      try {
-        trainerClass = Class.forName(alogrithmValue);
-      } catch (ClassNotFoundException e) {
-      }
-    }
-    
     if(trainerClass != null) {
       
       if (EventTrainer.class.isAssignableFrom(trainerClass)) {
@@ -115,7 +102,30 @@ public class TrainerFactory {
         return TrainerType.SEQUENCE_TRAINER;
       }
     }
+
+    // Try to load the different trainers, and return the type on success
     
+    try {
+      ExtensionLoader.instantiateExtension(EventTrainer.class, alogrithmValue);
+      return TrainerType.EVENT_MODEL_TRAINER; 
+    }
+    catch (ExtensionNotLoadedException e) {
+    }
+    
+    try {
+      ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, alogrithmValue);
+      return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER;
+    }
+    catch (ExtensionNotLoadedException e) {
+    }
+
+    try {
+      ExtensionLoader.instantiateExtension(SequenceTrainer.class, alogrithmValue);
+      return TrainerType.SEQUENCE_TRAINER;
+    }
+    catch (ExtensionNotLoadedException e) {
+    }
+
     return null;
   }
   
@@ -197,52 +207,59 @@ public class TrainerFactory {
         .equals(trainParams.get(AbstractTrainer.ALGORITHM_PARAM));
   }
   
-  
   public static SequenceTrainer getSequenceModelTrainer(Map<String, String> trainParams,
       Map<String, String> reportMap) {
-    String trainerType = getTrainerTypeInt(trainParams);
-    if (BUILTIN_TRAINERS.containsKey(trainerType)) {
-      return TrainerFactory.<SequenceTrainer> create(
-          BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
-    } else {
-      return TrainerFactory.<SequenceTrainer> create(trainerType, trainParams,
-          reportMap);
-    }
+    String trainerType = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
     
+    if (trainerType != null) {
+      if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+        return TrainerFactory.<SequenceTrainer> createBuiltinTrainer(
+            BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+      } else {
+        return ExtensionLoader.instantiateExtension(SequenceTrainer.class, trainerType);
+      }
+    }
+    else {
+      throw new IllegalArgumentException("Trainer type couldn't be determined!");
+    }
   }
   
   public static EventModelSequenceTrainer getEventModelSequenceTrainer(Map<String, String> trainParams,
       Map<String, String> reportMap) {
-    return getSequenceTrainer(trainParams, reportMap);
+    String trainerType = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+    if (trainerType != null) {
+      if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+        return TrainerFactory.<EventModelSequenceTrainer> createBuiltinTrainer(
+            BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+      } else {
+        return ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class, trainerType);
+      }
+    }
+    else {
+      throw new IllegalArgumentException("Trainer type couldn't be determined!");
+    }
   }
   
   @Deprecated
   public static EventModelSequenceTrainer getSequenceTrainer(
       Map<String, String> trainParams, Map<String, String> reportMap) {
-    String trainerType = getTrainerTypeInt(trainParams);
-    if (BUILTIN_TRAINERS.containsKey(trainerType)) {
-      return TrainerFactory.<EventModelSequenceTrainer> create(
-          BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
-    } else {
-      return TrainerFactory.<EventModelSequenceTrainer> create(trainerType, trainParams,
-          reportMap);
-    }
+    return getEventModelSequenceTrainer(trainParams, reportMap);
   }
 
   public static EventTrainer getEventTrainer(Map<String, String> trainParams,
       Map<String, String> reportMap) {
-    String trainerType = getTrainerTypeInt(trainParams);
-    if(trainerType == null) {
+    String trainerType = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+    if (trainerType == null) {
       // default to MAXENT
       return new GIS(trainParams, reportMap);
     }
-    
-    if (BUILTIN_TRAINERS.containsKey(trainerType)) {
-      return TrainerFactory.<EventTrainer> create(
-          BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
-    } else {
-      return TrainerFactory.<EventTrainer> create(trainerType, trainParams,
-          reportMap);
+    else {
+      if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+        return TrainerFactory.<EventTrainer> createBuiltinTrainer(
+            BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+      } else {
+        return ExtensionLoader.instantiateExtension(EventTrainer.class, trainerType);
+      }
     }
   }
   
@@ -252,11 +269,9 @@ public class TrainerFactory {
     
     String algorithmName = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
     
-    // to check the algorithm we verify if it is a built in trainer, or if we can instantiate
-    // one if it is a class name
-    
+    // If a trainer type can be determined, then the trainer is valid!
     if (algorithmName != null && 
-        !(BUILTIN_TRAINERS.containsKey(algorithmName) || canLoadTrainer(algorithmName))) {
+        !(BUILTIN_TRAINERS.containsKey(algorithmName) || getTrainerType(trainParams) != null)) {
       return false;
     }
 
@@ -285,44 +300,7 @@ public class TrainerFactory {
     return true;
   }
 
-  private static boolean canLoadTrainer(String className) {
-    try {
-      Class<?> trainerClass = Class.forName(className);
-      if(trainerClass != null &&
-          (EventTrainer.class.isAssignableFrom(trainerClass)
-              || EventModelSequenceTrainer.class.isAssignableFrom(trainerClass) || SequenceTrainer.class.isAssignableFrom(trainerClass))) {
-        return true;
-      }
-    } catch (ClassNotFoundException e) {
-      // fail
-    }
-    return false;
-  }
-
-  private static String getTrainerTypeInt(Map<String, String> trainParams) {
-    return trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
-  }
-
-  private static <T> T create(String className,
-      Map<String, String> trainParams, Map<String, String> reportMap) {
-    T theFactory = null;
-
-    try {
-      // TODO: won't work in OSGi!
-      Class<T> trainerClass = (Class<T>) Class.forName(className);
-      
-      theFactory = create(trainerClass, trainParams, reportMap);
-    } catch (Exception e) {
-      String msg = "Could not instantiate the " + className
-          + ". The initialization throw an exception.";
-      System.err.println(msg);
-      e.printStackTrace();
-      throw new IllegalArgumentException(msg, e);
-    }
-    return theFactory;
-  }
-
-  private static <T> T create(Class<T> trainerClass,
+  private static <T> T createBuiltinTrainer(Class<T> trainerClass,
       Map<String, String> trainParams, Map<String, String> reportMap) {
     T theTrainer = null;
     if (trainerClass != null) {

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java?rev=1569390&r1=1569389&r2=1569390&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java Tue Feb 18 15:37:25 2014
@@ -73,7 +73,7 @@ public class TrainUtil {
     if(!TrainerFactory.isSupportSequence(trainParams)) {
       throw new IllegalArgumentException("EventTrain is not supported");
     }
-    EventModelSequenceTrainer trainer = TrainerFactory.getSequenceTrainer(trainParams, reportMap);
+    EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(trainParams, reportMap);
     
     return trainer.train(events);
   }