You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by co...@apache.org on 2013/06/07 00:11:39 UTC
svn commit: r1490460 - in
/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools:
ml/TrainerFactory.java ml/model/TrainUtil.java util/TrainingParameters.java
Author: colen
Date: Thu Jun 6 22:11:39 2013
New Revision: 1490460
URL: http://svn.apache.org/r1490460
Log:
OPENNLP-581 First proposal for a TrainerFactory. Again, I only changed TrainUtil to avoid changing many classes. I am still not happy with the implementation, but would like feedback.
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (with props)
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1490460&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Thu Jun 6 22:11:39 2013
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.lang.reflect.Constructor;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
+import opennlp.tools.ml.perceptron.PerceptronTrainer;
+import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
+
+public class TrainerFactory {
+ /*
+ * Creates EventTrainer and SequenceTrainer instances from a map of training
+ * parameters. The algorithm name read from AbstractTrainer.ALGORITHM_PARAM is
+ * first looked up among the built-in trainers; any other value is treated as
+ * a class name and loaded reflectively. Reflectively created trainers are
+ * instantiated through a public (Map, Map) constructor.
+ */
+ // built-in trainers, keyed by algorithm name
+ private static final Map<String, Class> BUILTIN_TRAINERS;
+ /*
+ * Fills the built-in table when the class is initialized and publishes it
+ * as an unmodifiable view.
+ * NOTE(review): the raw Class values are later passed to
+ * create(Class<T>, ...) through an unchecked conversion, so nothing verifies
+ * that an entry implements the trainer interface the caller asks for.
+ */
+ static {
+ Map<String, Class> _trainers = new HashMap<String, Class>();
+ _trainers.put(GIS.MAXENT_VALUE, GIS.class);
+ _trainers.put(QNTrainer.MAXENT_QN_VALUE, QNTrainer.class);
+ _trainers.put(PerceptronTrainer.PERCEPTRON_VALUE, PerceptronTrainer.class);
+ _trainers.put(SimplePerceptronSequenceTrainer.PERCEPTRON_SEQUENCE_VALUE,
+ SimplePerceptronSequenceTrainer.class);
+
+ BUILTIN_TRAINERS = Collections.unmodifiableMap(_trainers);
+ }
+ /**
+ * Tells whether the given parameters select event-based training.
+ * A missing AbstractTrainer.TRAINER_TYPE_PARAM entry counts as event
+ * training; otherwise the entry must equal EventTrainer.EVENT_VALUE.
+ *
+ * @param trainParams the training parameters; dereferenced, so not null
+ * @return true if event training is selected or no trainer type is set
+ */
+ public static boolean isSupportEvent(Map<String, String> trainParams) {
+ if (trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM) != null) {
+ if(EventTrainer.EVENT_VALUE.equals(trainParams
+ .get(AbstractTrainer.TRAINER_TYPE_PARAM))) {
+ return true;
+ }
+ return false;
+ } else {
+ return true; // default to event train
+ }
+ }
+ /**
+ * Tells whether the given parameters select sequence-based training, i.e.
+ * whether AbstractTrainer.TRAINER_TYPE_PARAM equals
+ * SequenceTrainer.SEQUENCE_VALUE. Unlike isSupportEvent, a missing entry
+ * yields false.
+ *
+ * @param trainParams the training parameters; dereferenced, so not null
+ * @return true only if the trainer type is explicitly the sequence value
+ */
+ public static boolean isSupportSequence(Map<String, String> trainParams) {
+ if (SequenceTrainer.SEQUENCE_VALUE.equals(trainParams
+ .get(AbstractTrainer.TRAINER_TYPE_PARAM))) {
+ return true;
+ }
+ return false;
+ }
+ /**
+ * Creates the sequence trainer named by AbstractTrainer.ALGORITHM_PARAM.
+ * A built-in name is resolved through BUILTIN_TRAINERS; any other value is
+ * used as a class name.
+ * NOTE(review): there is no null check here, unlike getEventTrainer. With
+ * no algorithm set, null is handed to create(String, ...), where
+ * Class.forName(null) fails and surfaces as an IllegalArgumentException.
+ * NOTE(review): the built-in table also holds GIS, which getEventTrainer
+ * returns as an EventTrainer; nothing here checks that the resolved class
+ * is actually a SequenceTrainer - confirm against the trainer hierarchy.
+ *
+ * @param trainParams parameters passed on to the trainer's constructor
+ * @param reportMap report map passed on to the trainer's constructor
+ * @return the newly created trainer
+ * @throws IllegalArgumentException if the trainer cannot be instantiated
+ */
+ public static SequenceTrainer getSequenceTrainer(
+ Map<String, String> trainParams, Map<String, String> reportMap) {
+ String trainerType = getTrainerType(trainParams);
+ if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+ return TrainerFactory.<SequenceTrainer> create(
+ BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+ } else {
+ return TrainerFactory.<SequenceTrainer> create(trainerType, trainParams,
+ reportMap);
+ }
+ }
+ /**
+ * Creates the event trainer named by AbstractTrainer.ALGORITHM_PARAM,
+ * falling back to GIS when no algorithm is set. A built-in name is resolved
+ * through BUILTIN_TRAINERS; any other value is used as a class name.
+ *
+ * @param trainParams parameters passed on to the trainer's constructor
+ * @param reportMap report map passed on to the trainer's constructor
+ * @return the newly created trainer
+ * @throws IllegalArgumentException if the trainer cannot be instantiated
+ */
+ public static EventTrainer getEventTrainer(Map<String, String> trainParams,
+ Map<String, String> reportMap) {
+ String trainerType = getTrainerType(trainParams);
+ if(trainerType == null) {
+ // default to MAXENT
+ return new GIS(trainParams, reportMap);
+ }
+
+ if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+ return TrainerFactory.<EventTrainer> create(
+ BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+ } else {
+ return TrainerFactory.<EventTrainer> create(trainerType, trainParams,
+ reportMap);
+ }
+ }
+ /**
+ * Reads the algorithm name, or null when the parameter is absent.
+ * Despite the method name this reads ALGORITHM_PARAM, not
+ * TRAINER_TYPE_PARAM.
+ */
+ private static String getTrainerType(Map<String, String> trainParams) {
+ return trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+ }
+ /**
+ * Loads className with Class.forName and instantiates it through
+ * create(Class, Map, Map).
+ * NOTE(review): the catch block below also catches the
+ * IllegalArgumentException already thrown (and printed) by
+ * create(Class, ...), so a constructor failure is printed and wrapped twice.
+ *
+ * @return the new instance
+ * @throws IllegalArgumentException if the class cannot be loaded or
+ * instantiated; the message and stack trace are also written to System.err
+ */
+ private static <T> T create(String className,
+ Map<String, String> trainParams, Map<String, String> reportMap) {
+ T theFactory = null;
+
+ try {
+ // TODO: won't work in OSGi!
+ Class<T> trainerClass = (Class<T>) Class.forName(className);
+ theFactory = create(trainerClass, trainParams, reportMap);
+ } catch (Exception e) {
+ String msg = "Could not instantiate the " + className
+ + ". The initialization throw an exception.";
+ System.err.println(msg);
+ e.printStackTrace();
+ throw new IllegalArgumentException(msg, e);
+ }
+ return theFactory;
+ }
+ /**
+ * Instantiates trainerClass through its public (Map, Map) constructor,
+ * passing trainParams and reportMap.
+ *
+ * @return the new instance, or null when trainerClass is null
+ * @throws IllegalArgumentException if the constructor is missing or fails;
+ * the message and stack trace are also written to System.err
+ */
+ private static <T> T create(Class<T> trainerClass,
+ Map<String, String> trainParams, Map<String, String> reportMap) {
+ T theTrainer = null;
+ if (trainerClass != null) {
+ try {
+ Constructor<T> contructor = trainerClass.getConstructor(Map.class,
+ Map.class);
+ theTrainer = contructor.newInstance(trainParams, reportMap);
+ } catch (Exception e) {
+ String msg = "Could not instantiate the "
+ + trainerClass.getCanonicalName()
+ + ". The initialization throw an exception.";
+ System.err.println(msg);
+ e.printStackTrace();
+ throw new IllegalArgumentException(msg, e);
+ }
+ }
+ return theTrainer;
+ }
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java?rev=1490460&r1=1490459&r2=1490460&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java Thu Jun 6 22:11:39 2013
@@ -23,6 +23,8 @@ import java.io.IOException;
import java.util.Map;
import opennlp.tools.ml.EventTrainer;
+import opennlp.tools.ml.SequenceTrainer;
+import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.maxent.GIS;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
@@ -37,33 +39,14 @@ public class TrainUtil {
public static final String PERCEPTRON_VALUE = "PERCEPTRON";
public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
-
public static final String CUTOFF_PARAM = "Cutoff";
- private static final int CUTOFF_DEFAULT = 5;
public static final String ITERATIONS_PARAM = "Iterations";
- private static final int ITERATIONS_DEFAULT = 100;
public static final String DATA_INDEXER_PARAM = "DataIndexer";
public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
-
- private static String getStringParam(Map<String, String> trainParams, String key,
- String defaultValue, Map<String, String> reportMap) {
-
- String valueString = trainParams.get(key);
-
- if (valueString == null)
- valueString = defaultValue;
-
- if (reportMap != null)
- reportMap.put(key, valueString);
-
- return valueString;
- }
-
-
public static boolean isValid(Map<String, String> trainParams) {
// TODO: Need to validate all parameters correctly ... error prone?!
@@ -108,30 +91,10 @@ public class TrainUtil {
public static AbstractModel train(EventStream events, Map<String, String> trainParams, Map<String, String> reportMap)
throws IOException {
- if (!isValid(trainParams))
- throw new IllegalArgumentException("trainParams are not valid!");
-
- if(isSequenceTraining(trainParams))
- throw new IllegalArgumentException("sequence training is not supported by this method!");
-
- String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
-
- EventTrainer trainer;
- if(PERCEPTRON_VALUE.equals(algorithmName)) {
-
- trainer = new PerceptronTrainer(trainParams, reportMap);
-
- } else if(MAXENT_VALUE.equals(algorithmName)) {
-
- trainer = new GIS(trainParams, reportMap);
-
- } else if(MAXENT_QN_VALUE.equals(algorithmName)) {
-
- trainer = new QNTrainer(trainParams, reportMap);
-
- } else {
- trainer = new GIS(trainParams, reportMap); // default to maxent?
+ if(!TrainerFactory.isSupportEvent(trainParams)) {
+ throw new IllegalArgumentException("EventTrain is not supported");
}
+ EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, reportMap);
return trainer.train(events);
}
@@ -147,8 +110,11 @@ public class TrainUtil {
public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
Map<String, String> reportMap) throws IOException {
- SimplePerceptronSequenceTrainer trainer = new SimplePerceptronSequenceTrainer(
- trainParams, reportMap);
+ if(!TrainerFactory.isSupportSequence(trainParams)) {
+ throw new IllegalArgumentException("EventTrain is not supported");
+ }
+ SequenceTrainer trainer = TrainerFactory.getSequenceTrainer(trainParams, reportMap);
+
return trainer.train(events);
}
}
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java?rev=1490460&r1=1490459&r2=1490460&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java Thu Jun 6 22:11:39 2013
@@ -25,9 +25,13 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
+import opennlp.tools.ml.EventTrainer;
+
public class TrainingParameters {
+ // TODO: are them duplicated?
public static final String ALGORITHM_PARAM = "Algorithm";
+ public static final String TRAINER_TYPE_PARAM = "TrainerType";
public static final String ITERATIONS_PARAM = "Iterations";
public static final String CUTOFF_PARAM = "Cutoff";
@@ -144,6 +148,7 @@ public class TrainingParameters {
public static final TrainingParameters defaultParams() {
TrainingParameters mlParams = new TrainingParameters();
mlParams.put(TrainingParameters.ALGORITHM_PARAM, "MAXENT");
+ mlParams.put(TrainingParameters.TRAINER_TYPE_PARAM, EventTrainer.EVENT_VALUE);
mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(100));
mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5));