You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by co...@apache.org on 2013/09/10 17:07:35 UTC

svn commit: r1521519 - in /opennlp/trunk/opennlp-tools/src: main/java/opennlp/tools/ml/ main/java/opennlp/tools/ml/model/ main/java/opennlp/tools/util/model/ test/java/opennlp/tools/ml/ test/java/opennlp/tools/ml/maxent/ test/java/opennlp/tools/ml/perceptron/

Author: colen
Date: Tue Sep 10 15:07:35 2013
New Revision: 1521519

URL: http://svn.apache.org/r1521519
Log:
OPENNLP-581 Deprecated TrainUtil methods and removed duplicated references to constants. Moved the isValid method to TrainerFactory, updated it to work with class names and created a junit test to validate it.

Added:
    opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java   (with props)
    opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java   (with props)
    opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java   (with props)
Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java
    opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java
    opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/perceptron/PerceptronPrepAttachTest.java

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1521519&r1=1521518&r2=1521519&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Tue Sep 10 15:07:35 2013
@@ -18,10 +18,8 @@
 package opennlp.tools.ml;
 
 import java.lang.reflect.Constructor;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import opennlp.tools.ml.maxent.GIS;
@@ -93,6 +91,59 @@ public class TrainerFactory {
           reportMap);
     }
   }
+  
+  public static boolean isValid(Map<String, String> trainParams) {
+
+    // TODO: Need to validate all parameters correctly ... error prone?!
+    
+    String algorithmName = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+    
+    // to check the algorithm we verify if it is a built in trainer, or if we can instantiate
+    // one if it is a class name
+    
+    if (algorithmName != null && 
+        !(BUILTIN_TRAINERS.containsKey(algorithmName) || canLoadTrainer(algorithmName))) {
+      return false;
+    }
+
+    try {
+      String cutoffString = trainParams.get(AbstractTrainer.CUTOFF_PARAM);
+      if (cutoffString != null) Integer.parseInt(cutoffString);
+      
+      String iterationsString = trainParams.get(AbstractTrainer.ITERATIONS_PARAM);
+      if (iterationsString != null) Integer.parseInt(iterationsString);
+    }
+    catch (NumberFormatException e) {
+      return false;
+    }
+    
+    String dataIndexer = trainParams.get(AbstractEventTrainer.DATA_INDEXER_PARAM);
+    
+    if (dataIndexer != null) {
+      if (!(AbstractEventTrainer.DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexer) 
+          || AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexer))) {
+        return false;
+      }
+    }
+    
+    // TODO: Check data indexing ... 
+     
+    return true;
+  }
+
+  private static boolean canLoadTrainer(String className) {
+    try {
+      Class<?> trainerClass = Class.forName(className);
+      if(trainerClass != null &&
+          (EventTrainer.class.isAssignableFrom(trainerClass)
+              || SequenceTrainer.class.isAssignableFrom(trainerClass))) {
+        return true;
+      }
+    } catch (ClassNotFoundException e) {
+      // fail
+    }
+    return false;
+  }
 
   private static String getTrainerType(Map<String, String> trainParams) {
     return trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
@@ -105,6 +156,7 @@ public class TrainerFactory {
     try {
       // TODO: won't work in OSGi!
       Class<T> trainerClass = (Class<T>) Class.forName(className);
+      
       theFactory = create(trainerClass, trainParams, reportMap);
     } catch (Exception e) {
       String msg = "Could not instantiate the " + className

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java?rev=1521519&r1=1521518&r2=1521519&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java Tue Sep 10 15:07:35 2013
@@ -25,69 +25,22 @@ import java.util.Map;
 import opennlp.tools.ml.EventTrainer;
 import opennlp.tools.ml.SequenceTrainer;
 import opennlp.tools.ml.TrainerFactory;
-import opennlp.tools.ml.maxent.GIS;
-import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
-import opennlp.tools.ml.perceptron.PerceptronTrainer;
-import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
 
 public class TrainUtil {
 
-  public static final String ALGORITHM_PARAM = "Algorithm";
-  
-  public static final String MAXENT_VALUE = "MAXENT";
-  public static final String MAXENT_QN_VALUE = "MAXENT_QN_EXPERIMENTAL";
-  public static final String PERCEPTRON_VALUE = "PERCEPTRON";
-  public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
-  
-  public static final String CUTOFF_PARAM = "Cutoff";
-  
-  public static final String ITERATIONS_PARAM = "Iterations";
-  
-  public static final String DATA_INDEXER_PARAM = "DataIndexer";
-  public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
-  public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
-  
+  /**
+   * @deprecated Use {@link TrainerFactory#isValid(Map)} instead.
+   */
   public static boolean isValid(Map<String, String> trainParams) {
-
-    // TODO: Need to validate all parameters correctly ... error prone?!
-    
-    String algorithmName = trainParams.get(ALGORITHM_PARAM);
-
-    if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName) ||
-    	MAXENT_QN_VALUE.equals(algorithmName) ||
-        PERCEPTRON_VALUE.equals(algorithmName) ||
-        PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
-      return false;
-    }
-
-    try {
-      String cutoffString = trainParams.get(CUTOFF_PARAM);
-      if (cutoffString != null) Integer.parseInt(cutoffString);
-      
-      String iterationsString = trainParams.get(ITERATIONS_PARAM);
-      if (iterationsString != null) Integer.parseInt(iterationsString);
-    }
-    catch (NumberFormatException e) {
-      return false;
-    }
-    
-    String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);
-    
-    if (dataIndexer != null) {
-      if (!("OnePass".equals(dataIndexer) || "TwoPass".equals(dataIndexer))) {
-        return false;
-      }
-    }
-    
-    // TODO: Check data indexing ... 
-     
-    return true;
+    return TrainerFactory.isValid(trainParams);
   }
   
-  
-  
   // TODO: Need a way to report results and settings back for inclusion in model ...
   
+  /**
+   * @deprecated Use {@link TrainerFactory#getEventTrainer(Map, Map)} to get an
+   *             {@link EventTrainer} instead.
+   */
   public static AbstractModel train(EventStream events, Map<String, String> trainParams, Map<String, String> reportMap) 
       throws IOException {
     
@@ -100,13 +53,19 @@ public class TrainUtil {
   }
   
   /**
-   * Detects if the training algorithm requires sequence based feature generation
-   * or not.
+   * Detects if the training algorithm requires sequence based feature
+   * generation or not.
+   * 
+   * @deprecated Use {@link TrainerFactory#isSupportSequence(Map)} instead.
    */
   public static boolean isSequenceTraining(Map<String, String> trainParams) {
-    return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
+	return TrainerFactory.isSupportSequence(trainParams);
   }
   
+  /**
+   * @deprecated Use {@link TrainerFactory#getSequenceTrainer(Map, Map)} to get an
+   *             {@link SequenceTrainer} instead.
+   */
   public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
       Map<String, String> reportMap) throws IOException {
     

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java?rev=1521519&r1=1521518&r2=1521519&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java Tue Sep 10 15:07:35 2013
@@ -28,6 +28,7 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 
+import opennlp.tools.ml.maxent.GIS;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.GenericModelWriter;
 import opennlp.tools.ml.model.MaxentModel;
@@ -141,7 +142,7 @@ public final class ModelUtil {
    */
   public static TrainingParameters createTrainingParameters(int iterations, int cutoff) {
     TrainingParameters mlParams = new TrainingParameters();
-    mlParams.put(TrainingParameters.ALGORITHM_PARAM, TrainUtil.MAXENT_VALUE);
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
     mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(iterations));
     mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(cutoff));
     

Added: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java?rev=1521519&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java Tue Sep 10 15:07:35 2013
@@ -0,0 +1,15 @@
+package opennlp.tools.ml;
+
+import java.io.IOException;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.EventStream;
+
+public class MockEventTrainer implements EventTrainer {
+
+  public AbstractModel train(EventStream events) throws IOException {
+    // TODO Auto-generated method stub
+    return null;
+  }
+
+}

Propchange: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Added: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java?rev=1521519&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java Tue Sep 10 15:07:35 2013
@@ -0,0 +1,14 @@
+package opennlp.tools.ml;
+
+import java.io.IOException;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.SequenceStream;
+
+public class MockSequenceTrainer implements SequenceTrainer {
+
+  public AbstractModel train(SequenceStream events) throws IOException {
+    return null;
+  }
+
+}

Propchange: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Added: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java?rev=1521519&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java (added)
+++ opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java Tue Sep 10 15:07:35 2013
@@ -0,0 +1,45 @@
+package opennlp.tools.ml;
+
+import static org.junit.Assert.*;
+import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.util.TrainingParameters;
+
+import org.junit.Before;
+import org.junit.Test;
+
+public class TrainerFactoryTest {
+  
+  private TrainingParameters mlParams;
+
+  @Before
+  public void setup() {
+    mlParams = new TrainingParameters();
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+    mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(10));
+    mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5));
+  }
+
+  @Test
+  public void testBuiltInValid() {
+    assertTrue(TrainerFactory.isValid(mlParams.getSettings()));
+  }
+
+  @Test
+  public void testSequenceTrainerValid() {
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, MockSequenceTrainer.class.getCanonicalName());
+    assertTrue(TrainerFactory.isValid(mlParams.getSettings()));
+  }
+
+  @Test
+  public void testEventTrainerValid() {
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, MockEventTrainer.class.getCanonicalName());
+    assertTrue(TrainerFactory.isValid(mlParams.getSettings()));
+  }
+
+  @Test
+  public void testInvalidTrainer() {
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, "xyz");
+    assertFalse(TrainerFactory.isValid(mlParams.getSettings()));
+  }
+
+}

Propchange: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Modified: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java?rev=1521519&r1=1521518&r2=1521519&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java (original)
+++ opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java Tue Sep 10 15:07:35 2013
@@ -24,6 +24,8 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.AbstractTrainer;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.TrainUtil;
 import opennlp.tools.ml.model.TwoPassDataIndexer;
@@ -56,10 +58,10 @@ public class MaxentPrepAttachTest {
   public void testMaxentOnPrepAttachDataWithParams() throws IOException {
     
     Map<String, String> trainParams = new HashMap<String, String>();
-    trainParams.put(TrainUtil.ALGORITHM_PARAM, TrainUtil.MAXENT_VALUE);
-    trainParams.put(TrainUtil.DATA_INDEXER_PARAM,
-        TrainUtil.DATA_INDEXER_TWO_PASS_VALUE);
-    trainParams.put(TrainUtil.CUTOFF_PARAM, Integer.toString(1));
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+    trainParams.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
+        AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE);
+    trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1));
     
     AbstractModel model = TrainUtil.train(createTrainingStream(), trainParams, null);
     
@@ -70,7 +72,7 @@ public class MaxentPrepAttachTest {
   public void testMaxentOnPrepAttachDataWithParamsDefault() throws IOException {
     
     Map<String, String> trainParams = new HashMap<String, String>();
-    trainParams.put(TrainUtil.ALGORITHM_PARAM, TrainUtil.MAXENT_VALUE);
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
     
     AbstractModel model = TrainUtil.train(createTrainingStream(), trainParams, null);
     

Modified: opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/perceptron/PerceptronPrepAttachTest.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/perceptron/PerceptronPrepAttachTest.java?rev=1521519&r1=1521518&r2=1521519&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/perceptron/PerceptronPrepAttachTest.java (original)
+++ opennlp/trunk/opennlp-tools/src/test/java/opennlp/tools/ml/perceptron/PerceptronPrepAttachTest.java Tue Sep 10 15:07:35 2013
@@ -24,6 +24,7 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+import opennlp.tools.ml.AbstractTrainer;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.TrainUtil;
 import opennlp.tools.ml.model.TwoPassDataIndexer;
@@ -48,8 +49,8 @@ public class PerceptronPrepAttachTest {
   public void testPerceptronOnPrepAttachDataWithSkippedAveraging() throws IOException {
     
     Map<String, String> trainParams = new HashMap<String, String>();
-    trainParams.put(TrainUtil.ALGORITHM_PARAM, TrainUtil.PERCEPTRON_VALUE);
-    trainParams.put(TrainUtil.CUTOFF_PARAM, Integer.toString(1));
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, PerceptronTrainer.PERCEPTRON_VALUE);
+    trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1));
     trainParams.put("UseSkippedAveraging", Boolean.toString(true));
     
     AbstractModel model = TrainUtil.train(createTrainingStream(), trainParams, null);
@@ -61,9 +62,9 @@ public class PerceptronPrepAttachTest {
   public void testPerceptronOnPrepAttachDataWithTolerance() throws IOException {
     
     Map<String, String> trainParams = new HashMap<String, String>();
-    trainParams.put(TrainUtil.ALGORITHM_PARAM, TrainUtil.PERCEPTRON_VALUE);
-    trainParams.put(TrainUtil.CUTOFF_PARAM, Integer.toString(1));
-    trainParams.put(TrainUtil.ITERATIONS_PARAM, Integer.toString(500));
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, PerceptronTrainer.PERCEPTRON_VALUE);
+    trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1));
+    trainParams.put(AbstractTrainer.ITERATIONS_PARAM, Integer.toString(500));
     trainParams.put("Tolerance", Double.toString(0.0001d));
     
     AbstractModel model = TrainUtil.train(createTrainingStream(), trainParams, null);
@@ -75,9 +76,9 @@ public class PerceptronPrepAttachTest {
   public void testPerceptronOnPrepAttachDataWithStepSizeDecrease() throws IOException {
     
     Map<String, String> trainParams = new HashMap<String, String>();
-    trainParams.put(TrainUtil.ALGORITHM_PARAM, TrainUtil.PERCEPTRON_VALUE);
-    trainParams.put(TrainUtil.CUTOFF_PARAM, Integer.toString(1));
-    trainParams.put(TrainUtil.ITERATIONS_PARAM, Integer.toString(500));
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, PerceptronTrainer.PERCEPTRON_VALUE);
+    trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1));
+    trainParams.put(AbstractTrainer.ITERATIONS_PARAM, Integer.toString(500));
     trainParams.put("StepSizeDecrease", Double.toString(0.06d));
     
     AbstractModel model = TrainUtil.train(createTrainingStream(), trainParams, null);