You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by co...@apache.org on 2013/06/05 19:24:56 UTC
svn commit: r1489970 - in
/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml: ./ maxent/
maxent/quasinewton/ model/ perceptron/
Author: colen
Date: Wed Jun 5 17:24:56 2013
New Revision: 1489970
URL: http://svn.apache.org/r1489970
Log:
OPENNLP-581 Added Trainer, EventTrainer and SequenceTrainer interfaces and some abstract implementations. Modified existing trainers to extend the abstract classes.
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java (with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java (with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java (with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/EventTrainer.java (with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java (with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/Trainer.java (with props)
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java?rev=1489970&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java Wed Jun 5 17:24:56 2013
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.EventStream;
+import opennlp.tools.ml.model.HashSumEventStream;
+import opennlp.tools.ml.model.OnePassDataIndexer;
+import opennlp.tools.ml.model.TwoPassDataIndexer;
+
+/**
+ * Base class for trainers that learn an {@link AbstractModel} from an
+ * {@link EventStream}.
+ * <p>
+ * It resolves the shared data indexing parameters (cutoff, data indexer
+ * implementation, sort-and-merge) and delegates the optimization itself to
+ * {@link #doTrain(DataIndexer)}.
+ */
+public abstract class AbstractEventTrainer extends AbstractTrainer implements
+    EventTrainer {
+
+  /** Training parameter that selects the data indexer implementation. */
+  public static final String DATA_INDEXER_PARAM = "DataIndexer";
+  /** {@link #DATA_INDEXER_PARAM} value selecting {@link OnePassDataIndexer}. */
+  public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
+  /**
+   * {@link #DATA_INDEXER_PARAM} value selecting {@link TwoPassDataIndexer};
+   * this is the default.
+   */
+  public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
+
+  /**
+   * @param trainParams the training parameters
+   * @param reportMap receives resolved parameters and training statistics,
+   *        may be {@code null}
+   */
+  public AbstractEventTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /** @return always {@code false}: event trainers do not consume sequences */
+  public boolean isSequenceTraining() {
+    return false;
+  }
+
+  /** @return always {@code true} */
+  public boolean isEventTraining() {
+    return true;
+  }
+
+  /**
+   * Extends the base validation with a check of {@link #DATA_INDEXER_PARAM}.
+   *
+   * @return {@code false} if the base parameters are invalid or the data
+   *         indexer name is neither {@value #DATA_INDEXER_ONE_PASS_VALUE} nor
+   *         {@value #DATA_INDEXER_TWO_PASS_VALUE}
+   */
+  @Override
+  public boolean isValid() {
+    if (!super.isValid()) {
+      return false;
+    }
+
+    // getStringParam falls back to the supplied default, so this is never null.
+    String dataIndexer = getStringParam(DATA_INDEXER_PARAM,
+        DATA_INDEXER_TWO_PASS_VALUE);
+
+    // TODO: Check data indexing ...
+    return DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexer)
+        || DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexer);
+  }
+
+  /** @return whether the data indexer should sort and merge identical events */
+  public abstract boolean isSortAndMerge();
+
+  /**
+   * Creates the data indexer selected by {@link #DATA_INDEXER_PARAM}, using
+   * the configured cutoff and {@link #isSortAndMerge()}.
+   *
+   * @param events the events to index
+   * @return a new data indexer over {@code events}
+   * @throws IOException if reading the events fails
+   * @throws IllegalStateException if the configured indexer name is unknown
+   */
+  public DataIndexer getDataIndexer(EventStream events) throws IOException {
+
+    String dataIndexerName = getStringParam(DATA_INDEXER_PARAM,
+        DATA_INDEXER_TWO_PASS_VALUE);
+
+    int cutoff = getCutoff();
+    boolean sortAndMerge = isSortAndMerge();
+
+    if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
+      return new OnePassDataIndexer(events, cutoff, sortAndMerge);
+    } else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
+      return new TwoPassDataIndexer(events, cutoff, sortAndMerge);
+    } else {
+      throw new IllegalStateException("Unexpected data indexer name: "
+          + dataIndexerName);
+    }
+  }
+
+  /**
+   * Trains a model from already indexed data.
+   *
+   * @param indexer the indexed training data
+   * @return the trained model
+   * @throws IOException if training fails on I/O
+   */
+  public abstract AbstractModel doTrain(DataIndexer indexer) throws IOException;
+
+  /**
+   * Indexes {@code events}, trains a model with {@link #doTrain(DataIndexer)}
+   * and reports the hash sum of the consumed events as
+   * {@code "Training-Eventhash"}.
+   *
+   * @param events the training events
+   * @return the trained model
+   * @throws IOException if reading the events or training fails
+   */
+  public final AbstractModel train(EventStream events) throws IOException {
+    // The indexer must read through the hashing wrapper. Otherwise the
+    // wrapper never sees an event and the reported hash is that of an
+    // empty stream.
+    HashSumEventStream hses = new HashSumEventStream(events);
+    DataIndexer indexer = getDataIndexer(hses);
+
+    AbstractModel model = doTrain(indexer);
+
+    addToReport("Training-Eventhash", hses.calculateHashSum().toString(16));
+    return model;
+  }
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java?rev=1489970&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java Wed Jun 5 17:24:56 2013
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.util.Map;
+
+/**
+ * Base class for trainers that learn a model from whole sequences rather
+ * than from independent events.
+ */
+public abstract class AbstractSequenceTrainer extends AbstractTrainer implements
+    SequenceTrainer {
+
+  /**
+   * @param trainParams the training parameters
+   * @param reportMap receives resolved parameters and training statistics,
+   *        may be {@code null}
+   */
+  public AbstractSequenceTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /** @return always {@code true}: this trainer consumes sequences */
+  public boolean isSequenceTraining() {
+    return true;
+  }
+
+  /** @return always {@code false}: this trainer does not consume flat events */
+  public boolean isEventTraining() {
+    return false;
+  }
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java?rev=1489970&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java Wed Jun 5 17:24:56 2013
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.util.Map;
+
+import opennlp.tools.ml.maxent.GIS;
+
+/**
+ * Base class for all trainers.
+ * <p>
+ * Holds the training parameters and an optional report map, and offers typed
+ * accessors for the parameters. {@link #getStringParam(String, String)} also
+ * records each resolved value in the report map; the numeric and boolean
+ * accessors do not.
+ */
+public abstract class AbstractTrainer implements Trainer {
+
+  /** Training parameter naming the algorithm to train with. */
+  public static final String ALGORITHM_PARAM = "Algorithm";
+
+  /** Training parameter holding the feature cutoff. */
+  public static final String CUTOFF_PARAM = "Cutoff";
+  /** Cutoff used when {@link #CUTOFF_PARAM} is not set. */
+  public static final int CUTOFF_DEFAULT = 5;
+
+  /** Training parameter holding the number of training iterations. */
+  public static final String ITERATIONS_PARAM = "Iterations";
+  /** Iteration count used when {@link #ITERATIONS_PARAM} is not set. */
+  public static final int ITERATIONS_DEFAULT = 100;
+
+  private final Map<String, String> trainParams;
+  // Optional: null disables reporting. It must be mutable when present,
+  // because resolved values are written into it.
+  private final Map<String, String> reportMap;
+
+  /**
+   * @param trainParams the training parameters; must not be {@code null}
+   * @param reportMap mutable map that receives resolved parameters and
+   *        training statistics, or {@code null} to disable reporting
+   * @throws IllegalArgumentException if {@code trainParams} is {@code null}
+   */
+  public AbstractTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) throws IllegalArgumentException {
+    // Fail fast here instead of with a NullPointerException on first lookup.
+    if (trainParams == null) {
+      throw new IllegalArgumentException("trainParams must not be null");
+    }
+    this.trainParams = trainParams;
+    this.reportMap = reportMap;
+  }
+
+  /** @return the configured algorithm, {@link GIS#MAXENT_VALUE} if unset */
+  public String getAlgorithm() {
+    return getStringParam(ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+  }
+
+  /** @return the configured cutoff, {@link #CUTOFF_DEFAULT} if unset */
+  public int getCutoff() {
+    return getIntParam(CUTOFF_PARAM, CUTOFF_DEFAULT);
+  }
+
+  /** @return the configured iteration count, {@link #ITERATIONS_DEFAULT} if unset */
+  public int getIterations() {
+    return getIntParam(ITERATIONS_PARAM, ITERATIONS_DEFAULT);
+  }
+
+  /**
+   * Looks up a string parameter and records the resolved value in the report
+   * map, if there is one.
+   *
+   * @param key the parameter name
+   * @param defaultValue the value used when {@code key} is absent
+   * @return the configured value or {@code defaultValue}
+   */
+  protected String getStringParam(String key, String defaultValue) {
+
+    String valueString = trainParams.get(key);
+
+    if (valueString == null) {
+      valueString = defaultValue;
+    }
+
+    if (reportMap != null) {
+      reportMap.put(key, valueString);
+    }
+
+    return valueString;
+  }
+
+  /**
+   * @return the parsed parameter, or {@code defaultValue} if it is absent
+   * @throws NumberFormatException if the value is present but not an int
+   */
+  protected int getIntParam(String key, int defaultValue) {
+    String valueString = trainParams.get(key);
+    return valueString != null ? Integer.parseInt(valueString) : defaultValue;
+  }
+
+  /**
+   * @return the parsed parameter, or {@code defaultValue} if it is absent
+   * @throws NumberFormatException if the value is present but not a double
+   */
+  protected double getDoubleParam(String key, double defaultValue) {
+    String valueString = trainParams.get(key);
+    return valueString != null ? Double.parseDouble(valueString) : defaultValue;
+  }
+
+  /**
+   * @return {@code true} only if the value equals "true" ignoring case, or
+   *         {@code defaultValue} if the parameter is absent
+   */
+  protected boolean getBooleanParam(String key, boolean defaultValue) {
+    String valueString = trainParams.get(key);
+    return valueString != null ? Boolean.parseBoolean(valueString) : defaultValue;
+  }
+
+  /** Stores {@code key}/{@code value} in the report map, if there is one. */
+  protected void addToReport(String key, String value) {
+    if (reportMap != null) {
+      reportMap.put(key, value);
+    }
+  }
+
+  /**
+   * Checks that the numeric base parameters ({@link #CUTOFF_PARAM},
+   * {@link #ITERATIONS_PARAM}) are either absent or parse as ints.
+   *
+   * @return {@code true} if the base parameters are usable
+   */
+  public boolean isValid() {
+
+    // TODO: Need to validate all parameters correctly ... error prone?!
+
+    // should validate if algorithm is set? What about the Parser?
+
+    return isIntOrAbsent(trainParams.get(CUTOFF_PARAM))
+        && isIntOrAbsent(trainParams.get(ITERATIONS_PARAM));
+  }
+
+  /** Returns {@code true} if {@code value} is null or parses as an int. */
+  private static boolean isIntOrAbsent(String value) {
+    if (value == null) {
+      return true;
+    }
+    try {
+      Integer.parseInt(value);
+      return true;
+    } catch (NumberFormatException e) {
+      return false;
+    }
+  }
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/EventTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/EventTrainer.java?rev=1489970&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/EventTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/EventTrainer.java Wed Jun 5 17:24:56 2013
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.io.IOException;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.EventStream;
+
+/**
+ * A {@link Trainer} that learns a model from a stream of independent events.
+ */
+public interface EventTrainer extends Trainer {
+
+  /**
+   * Trains a model on the given events.
+   *
+   * @param events the training events
+   * @return the trained model
+   * @throws IOException if reading the events fails
+   */
+  AbstractModel train(EventStream events) throws IOException;
+
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/EventTrainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java?rev=1489970&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java Wed Jun 5 17:24:56 2013
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.io.IOException;
+
+import opennlp.tools.ml.model.AbstractModel;
+import opennlp.tools.ml.model.SequenceStream;
+
+/**
+ * A {@link Trainer} that learns a model from a stream of sequences.
+ */
+public interface SequenceTrainer extends Trainer {
+
+  /**
+   * Trains a model on the given sequences.
+   *
+   * @param events the training sequences
+   * @return the trained model
+   * @throws IOException if reading the sequences fails
+   */
+  AbstractModel train(SequenceStream events) throws IOException;
+
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/SequenceTrainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/Trainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/Trainer.java?rev=1489970&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/Trainer.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/Trainer.java Wed Jun 5 17:24:56 2013
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+/**
+ * Common supertype of all model trainers. It tells callers which kind of
+ * training input a trainer consumes.
+ */
+public interface Trainer {
+
+  /** @return {@code true} if this trainer consumes sequences */
+  boolean isSequenceTraining();
+
+  /** @return {@code true} if this trainer consumes independent events */
+  boolean isEventTraining();
+
+}
Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/Trainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java?rev=1489970&r1=1489969&r2=1489970&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java Wed Jun 5 17:24:56 2013
@@ -20,7 +20,11 @@
package opennlp.tools.ml.maxent;
import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EventStream;
import opennlp.tools.ml.model.Prior;
@@ -30,7 +34,10 @@ import opennlp.tools.ml.model.UniformPri
* A Factory class which uses instances of GISTrainer to create and train
* GISModels.
*/
-public class GIS {
+public class GIS extends AbstractEventTrainer {
+
+ /**
+ * Value of the "Algorithm" training parameter that selects GIS maxent
+ * training; it is also the algorithm AbstractTrainer.getAlgorithm() falls
+ * back to when none is configured.
+ */
+ public static final String MAXENT_VALUE = "MAXENT";
+
/**
* Set this to false if you don't want messages about the progress of model
* training displayed. Alternately, you can use the overloaded version of
@@ -45,6 +52,53 @@ public class GIS {
*/
public static double SMOOTHING_OBSERVATION = 0.1;
+ // >> members related to AbstractEventTrainer
+  /**
+   * Creates a GIS trainer configured from {@code trainParams}.
+   *
+   * @param trainParams the training parameters
+   * @param reportMap receives resolved parameters and training statistics,
+   *        may be {@code null}
+   */
+  public GIS(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Creates a GIS trainer with default parameters and no report map.
+   * <p>
+   * The report map is {@code null} rather than {@code Collections.emptyMap()}.
+   * Training writes resolved parameters and the event hash into the report
+   * map, and the immutable empty map would throw
+   * {@link UnsupportedOperationException} on the first such write.
+   */
+  public GIS() {
+    super(Collections.<String, String> emptyMap(), null);
+  }
+
+  /**
+   * @return {@code true} if the common parameters are valid and the
+   *         configured algorithm is {@link #MAXENT_VALUE}
+   */
+  @Override
+  public boolean isValid() {
+    if (!super.isValid()) {
+      return false;
+    }
+
+    String algorithmName = getAlgorithm();
+    return algorithmName == null || MAXENT_VALUE.equals(algorithmName);
+  }
+
+  /** @return always {@code true}: GIS indexes with sort-and-merge enabled */
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+
+  /**
+   * Trains a GIS model on the indexed data, using the configured
+   * "Iterations" and "Threads" (default 1) parameters.
+   *
+   * @param indexer the indexed training data
+   * @return the trained model
+   * @throws IllegalArgumentException if the training parameters are invalid
+   * @throws IOException if training fails on I/O
+   */
+  @Override
+  public AbstractModel doTrain(DataIndexer indexer) throws IOException {
+    if (!isValid()) {
+      throw new IllegalArgumentException("trainParams are not valid!");
+    }
+
+    int iterationCount = getIterations();
+    int threadCount = getIntParam("Threads", 1);
+
+    // NOTE(review): the literal 0 is presumably the cutoff, which the data
+    // indexer already applied; confirm against trainModel's signature.
+    return trainModel(iterationCount, indexer, true, false, null, 0, threadCount);
+  }
+
+ // << members related to AbstractEventTrainer
+
/**
* Train a model using the GIS algorithm, assuming 100 iterations and no
* cutoff.
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java?rev=1489970&r1=1489969&r2=1489970&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java Wed Jun 5 17:24:56 2013
@@ -18,14 +18,22 @@
*/
package opennlp.tools.ml.maxent.quasinewton;
+import java.io.IOException;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
/**
* maxent model trainer using l-bfgs algorithm.
*/
-public class QNTrainer {
+public class QNTrainer extends AbstractEventTrainer {
+
+ /** Value of the "Algorithm" training parameter that selects this L-BFGS trainer. */
+ public static final String MAXENT_QN_VALUE = "MAXENT_QN_EXPERIMENTAL";
+
// constants for optimization.
private static final double CONVERGE_TOLERANCE = 1.0E-10;
private static final int MAX_M = 15;
@@ -61,6 +69,9 @@ public class QNTrainer {
}
public QNTrainer(int m, int maxFctEval, boolean verbose) {
+ super(Collections.<String, String> emptyMap(), Collections
+ .<String, String> emptyMap());
+
this.verbose = verbose;
if (m > MAX_M) {
this.m = MAX_M;
@@ -76,6 +87,62 @@ public class QNTrainer {
}
}
+ // >> members related to AbstractEventTrainer
+  /**
+   * Creates an L-BFGS trainer configured from {@code trainParams}.
+   * <p>
+   * Reads "numOfUpdates" (default {@link #DEFAULT_M}, capped at
+   * {@code MAX_M}) and "maxFctEval" (default {@link #DEFAULT_MAX_FCT_EVAL};
+   * negative values fall back to the default and values above
+   * {@code MAX_FCT_EVAL} are capped). Verbose output is always enabled.
+   *
+   * @param trainParams the training parameters
+   * @param reportMap receives resolved parameters and training statistics,
+   *        may be {@code null}
+   */
+  public QNTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+
+    int numOfUpdates = getIntParam("numOfUpdates", DEFAULT_M);
+    int requestedMaxFctEval = getIntParam("maxFctEval", DEFAULT_MAX_FCT_EVAL);
+
+    this.verbose = true;
+    this.m = Math.min(numOfUpdates, MAX_M);
+    this.maxFctEval = (requestedMaxFctEval < 0)
+        ? DEFAULT_MAX_FCT_EVAL
+        : Math.min(requestedMaxFctEval, MAX_FCT_EVAL);
+  }
+
+  /**
+   * @return {@code true} if the common parameters are valid and the
+   *         configured algorithm is {@link #MAXENT_QN_VALUE}; when no
+   *         algorithm is configured, this trainer's own algorithm is assumed
+   */
+  @Override
+  public boolean isValid() {
+    if (!super.isValid()) {
+      return false;
+    }
+
+    // Do not use getAlgorithm(): it falls back to GIS.MAXENT_VALUE, which
+    // makes a QNTrainer created without an explicit "Algorithm" parameter
+    // reject itself.
+    String algorithmName = getStringParam(ALGORITHM_PARAM, MAXENT_QN_VALUE);
+
+    return MAXENT_QN_VALUE.equals(algorithmName);
+  }
+
+  /** @return always {@code true}: L-BFGS indexes with sort-and-merge enabled */
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+
+  /**
+   * Trains an L-BFGS maxent model on the indexed data.
+   *
+   * @param indexer the indexed training data
+   * @return the trained model
+   * @throws IllegalArgumentException if the training parameters are invalid
+   * @throws IOException if training fails on I/O
+   */
+  @Override
+  public AbstractModel doTrain(DataIndexer indexer) throws IOException {
+    if (!isValid()) {
+      throw new IllegalArgumentException("trainParams are not valid!");
+    }
+    return trainModel(indexer);
+  }
+
+ // << members related to AbstractEventTrainer
+
public QNModel trainModel(DataIndexer indexer) {
LogLikelihoodFunction objectiveFunction = generateFunction(indexer);
this.dimension = objectiveFunction.getDomainDimension();
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java?rev=1489970&r1=1489969&r2=1489970&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java Wed Jun 5 17:24:56 2013
@@ -22,6 +22,8 @@ package opennlp.tools.ml.model;
import java.io.IOException;
import java.util.Map;
+import opennlp.tools.ml.EventTrainer;
+import opennlp.tools.ml.maxent.GIS;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
@@ -61,38 +63,6 @@ public class TrainUtil {
return valueString;
}
- private static int getIntParam(Map<String, String> trainParams, String key,
- int defaultValue, Map<String, String> reportMap) {
-
- String valueString = trainParams.get(key);
-
- if (valueString != null)
- return Integer.parseInt(valueString);
- else
- return defaultValue;
- }
-
- private static double getDoubleParam(Map<String, String> trainParams, String key,
- double defaultValue, Map<String, String> reportMap) {
-
- String valueString = trainParams.get(key);
-
- if (valueString != null)
- return Double.parseDouble(valueString);
- else
- return defaultValue;
- }
-
- private static boolean getBooleanParam(Map<String, String> trainParams, String key,
- boolean defaultValue, Map<String, String> reportMap) {
-
- String valueString = trainParams.get(key);
-
- if (valueString != null)
- return Boolean.parseBoolean(valueString);
- else
- return defaultValue;
- }
public static boolean isValid(Map<String, String> trainParams) {
@@ -146,81 +116,24 @@ public class TrainUtil {
String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
- int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
-
- int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
-
- boolean sortAndMerge;
-
- if (MAXENT_VALUE.equals(algorithmName) || MAXENT_QN_VALUE.equals(algorithmName))
- sortAndMerge = true;
- else if (PERCEPTRON_VALUE.equals(algorithmName))
- sortAndMerge = false;
- else
- throw new IllegalStateException("Unexpected algorithm name: " + algorithmName);
-
- HashSumEventStream hses = new HashSumEventStream(events);
-
- String dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
- DATA_INDEXER_TWO_PASS_VALUE, reportMap);
-
- DataIndexer indexer = null;
-
- if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
- indexer = new OnePassDataIndexer(hses, cutoff, sortAndMerge);
- }
- else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
- indexer = new TwoPassDataIndexer(hses, cutoff, sortAndMerge);
- }
- else {
- throw new IllegalStateException("Unexpected data indexer name: " + dataIndexerName);
- }
-
- AbstractModel model;
- if (MAXENT_VALUE.equals(algorithmName)) {
-
- int threads = getIntParam(trainParams, "Threads", 1, reportMap);
-
- model = opennlp.tools.ml.maxent.GIS.trainModel(iterations, indexer,
- true, false, null, 0, threads);
- }
- else if (MAXENT_QN_VALUE.equals(algorithmName)) {
- int m = getIntParam(trainParams, "numOfUpdates", QNTrainer.DEFAULT_M, reportMap);
- int maxFctEval = getIntParam(trainParams, "maxFctEval", QNTrainer.DEFAULT_MAX_FCT_EVAL, reportMap);
- model = new QNTrainer(m, maxFctEval, true).trainModel(indexer);
- }
- else if (PERCEPTRON_VALUE.equals(algorithmName)) {
- boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
-
- boolean useSkippedAveraging = getBooleanParam(trainParams, "UseSkippedAveraging", false, reportMap);
-
- // overwrite otherwise it might not work
- if (useSkippedAveraging)
- useAverage = true;
+ EventTrainer trainer;
+ if(PERCEPTRON_VALUE.equals(algorithmName)) {
- double stepSizeDecrease = getDoubleParam(trainParams, "StepSizeDecrease", 0, reportMap);
+ trainer = new PerceptronTrainer(trainParams, reportMap);
- double tolerance = getDoubleParam(trainParams, "Tolerance", PerceptronTrainer.TOLERANCE_DEFAULT, reportMap);
+ } else if(MAXENT_VALUE.equals(algorithmName)) {
- opennlp.tools.ml.perceptron.PerceptronTrainer perceptronTrainer = new opennlp.tools.ml.perceptron.PerceptronTrainer();
- perceptronTrainer.setSkippedAveraging(useSkippedAveraging);
+ trainer = new GIS(trainParams, reportMap);
- if (stepSizeDecrease > 0)
- perceptronTrainer.setStepSizeDecrease(stepSizeDecrease);
+ } else if(MAXENT_QN_VALUE.equals(algorithmName)) {
- perceptronTrainer.setTolerance(tolerance);
-
- model = perceptronTrainer.trainModel(
- iterations, indexer, cutoff, useAverage);
- }
- else {
- throw new IllegalStateException("Algorithm not supported: " + algorithmName);
- }
+ trainer = new QNTrainer(trainParams, reportMap);
- if (reportMap != null)
- reportMap.put("Training-Eventhash", hses.calculateHashSum().toString(16));
+ } else {
+ trainer = new GIS(trainParams, reportMap); // default to maxent?
+ }
- return model;
+ return trainer.train(events);
}
/**
@@ -234,18 +147,8 @@ public class TrainUtil {
public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
Map<String, String> reportMap) throws IOException {
- if (!isValid(trainParams))
- throw new IllegalArgumentException("trainParams are not valid!");
-
- if (!isSequenceTraining(trainParams))
- throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
-
- int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
- int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
-
- boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
-
- return new SimplePerceptronSequenceTrainer().trainModel(
- iterations, events, cutoff,useAverage);
+ // Parameter validation, including the algorithm check, is now done by
+ // SimplePerceptronSequenceTrainer.train(). It throws
+ // IllegalArgumentException("trainParams are not valid!") on failure.
+ SimplePerceptronSequenceTrainer trainer = new SimplePerceptronSequenceTrainer(
+ trainParams, reportMap);
+ return trainer.train(events);
}
}
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java?rev=1489970&r1=1489969&r2=1489970&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java Wed Jun 5 17:24:56 2013
@@ -19,6 +19,11 @@
package opennlp.tools.ml.perceptron;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+
+import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
@@ -32,8 +37,9 @@ import opennlp.tools.ml.model.MutableCon
* with the Perceptron Algorithm. Michael Collins, EMNLP 2002.
*
*/
-public class PerceptronTrainer {
+public class PerceptronTrainer extends AbstractEventTrainer {
+ /** Value of the "Algorithm" training parameter that selects perceptron training. */
+ public static final String PERCEPTRON_VALUE = "PERCEPTRON";
+ /** Default for the "Tolerance" training parameter. */
public static final double TOLERANCE_DEFAULT = .00001;
/** Number of unique events which occurred in the event set. */
@@ -77,6 +83,73 @@ public class PerceptronTrainer {
private boolean useSkippedlAveraging;
+ // >> members related to AbstractEventTrainer
+  /**
+   * Creates a perceptron trainer configured from {@code trainParams}.
+   *
+   * @param trainParams the training parameters
+   * @param reportMap receives resolved parameters and training statistics,
+   *        may be {@code null}
+   */
+  public PerceptronTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Creates a perceptron trainer with default parameters and no report map.
+   * <p>
+   * The report map is {@code null} rather than {@code Collections.emptyMap()}.
+   * Training writes resolved parameters and the event hash into the report
+   * map, and the immutable empty map would throw
+   * {@link UnsupportedOperationException} on the first such write.
+   */
+  public PerceptronTrainer() {
+    super(Collections.<String, String> emptyMap(), null);
+  }
+
+  /**
+   * @return {@code true} if the common parameters are valid and the
+   *         configured algorithm is {@link #PERCEPTRON_VALUE}; when no
+   *         algorithm is configured, perceptron training is assumed
+   */
+  @Override
+  public boolean isValid() {
+    if (!super.isValid()) {
+      return false;
+    }
+
+    // Do not use getAlgorithm(): it falls back to GIS.MAXENT_VALUE, which
+    // makes a PerceptronTrainer created without an explicit "Algorithm"
+    // parameter (e.g. via the no-argument constructor) reject itself.
+    String algorithmName = getStringParam(ALGORITHM_PARAM, PERCEPTRON_VALUE);
+
+    return PERCEPTRON_VALUE.equals(algorithmName);
+  }
+
+  /** @return always {@code false}: perceptron training keeps duplicate events */
+  @Override
+  public boolean isSortAndMerge() {
+    return false;
+  }
+
+  /**
+   * Configures this trainer from the training parameters and trains a
+   * perceptron model on the indexed data.
+   * <p>
+   * Reads "Iterations", "Cutoff", "UseAverage" (default {@code true}),
+   * "UseSkippedAveraging" (default {@code false}), "StepSizeDecrease"
+   * (applied only if positive) and "Tolerance" (default
+   * {@link #TOLERANCE_DEFAULT}).
+   *
+   * @param indexer the indexed training data
+   * @return the trained model
+   * @throws IllegalArgumentException if the training parameters are invalid
+   * @throws IOException if training fails on I/O
+   */
+  @Override
+  public AbstractModel doTrain(DataIndexer indexer) throws IOException {
+    if (!isValid()) {
+      throw new IllegalArgumentException("trainParams are not valid!");
+    }
+
+    int numIterations = getIterations();
+    int featureCutoff = getCutoff();
+
+    boolean averaged = getBooleanParam("UseAverage", true);
+    boolean skippedAveraging = getBooleanParam("UseSkippedAveraging", false);
+    // Skipped averaging forces averaging on ("otherwise it might not work").
+    averaged = averaged || skippedAveraging;
+
+    double stepDecrease = getDoubleParam("StepSizeDecrease", 0);
+    double convergenceTolerance = getDoubleParam("Tolerance", TOLERANCE_DEFAULT);
+
+    setSkippedAveraging(skippedAveraging);
+    if (stepDecrease > 0) {
+      setStepSizeDecrease(stepDecrease);
+    }
+    setTolerance(convergenceTolerance);
+
+    return trainModel(numIterations, indexer, featureCutoff, averaged);
+  }
+
+ // << members related to AbstractEventTrainer
+
/**
* Specifies the tolerance. If the change in training set accuracy
* is less than this, stop iterating.
Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java?rev=1489970&r1=1489969&r2=1489970&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java Wed Jun 5 17:24:56 2013
@@ -20,9 +20,11 @@
package opennlp.tools.ml.perceptron;
import java.io.IOException;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import opennlp.tools.ml.AbstractSequenceTrainer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
@@ -42,7 +44,9 @@ import opennlp.tools.ml.model.SequenceSt
* Specifically only updates are applied to tokens which were incorrectly tagged by a sequence tagger
* rather than to all feature across the sequence which differ from the training sequence.
*/
-public class SimplePerceptronSequenceTrainer {
+public class SimplePerceptronSequenceTrainer extends AbstractSequenceTrainer {
+
+ /** Value of the "Algorithm" training parameter that selects sequence perceptron training. */
+ public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
private boolean printMessages = true;
private int iterations;
@@ -81,6 +85,48 @@ public class SimplePerceptronSequenceTra
private String[] predLabels;
int numSequences;
+ // >> members related to AbstractSequenceTrainer
+  /**
+   * Creates a sequence perceptron trainer configured from {@code trainParams}.
+   *
+   * @param trainParams the training parameters
+   * @param reportMap receives resolved parameters and training statistics,
+   *        may be {@code null}
+   */
+  public SimplePerceptronSequenceTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Creates a sequence perceptron trainer with default parameters and no
+   * report map.
+   * <p>
+   * The report map is {@code null} rather than {@code Collections.emptyMap()}.
+   * Parameter resolution writes into the report map, and the immutable empty
+   * map would throw {@link UnsupportedOperationException} on the first such
+   * write.
+   */
+  public SimplePerceptronSequenceTrainer() {
+    super(Collections.<String, String> emptyMap(), null);
+  }
+
+  /**
+   * @return {@code true} if the common parameters are valid and the
+   *         configured algorithm is {@link #PERCEPTRON_SEQUENCE_VALUE}; when
+   *         no algorithm is configured, sequence perceptron training is assumed
+   */
+  @Override
+  public boolean isValid() {
+    if (!super.isValid()) {
+      return false;
+    }
+
+    // Do not use getAlgorithm(): it falls back to GIS.MAXENT_VALUE, which
+    // makes a trainer created without an explicit "Algorithm" parameter
+    // (e.g. via the no-argument constructor) reject itself.
+    String algorithmName = getStringParam(ALGORITHM_PARAM,
+        PERCEPTRON_SEQUENCE_VALUE);
+
+    return PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName);
+  }
+
+  /**
+   * Trains a sequence perceptron model using the configured "Iterations",
+   * "Cutoff" and "UseAverage" (default {@code true}) parameters.
+   *
+   * @param events the training sequences
+   * @return the trained model
+   * @throws IllegalArgumentException if the training parameters are invalid
+   * @throws IOException if reading the sequences fails
+   */
+  public AbstractModel train(SequenceStream events) throws IOException {
+    if (!isValid()) {
+      throw new IllegalArgumentException("trainParams are not valid!");
+    }
+
+    int numIterations = getIterations();
+    int featureCutoff = getCutoff();
+    boolean averaged = getBooleanParam("UseAverage", true);
+
+    return trainModel(numIterations, events, featureCutoff, averaged);
+  }
+
+ // << members related to AbstractSequenceTrainer
+
public AbstractModel trainModel(int iterations, SequenceStream sequenceStream, int cutoff, boolean useAverage) throws IOException {
this.iterations = iterations;
this.sequenceStream = sequenceStream;