You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2011/05/19 16:37:31 UTC
svn commit: r1124852 - in
/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model:
HashSumEventStream.java TrainUtil.java
Author: joern
Date: Thu May 19 14:37:31 2011
New Revision: 1124852
URL: http://svn.apache.org/viewvc?rev=1124852&view=rev
Log:
OPENNLP-175 Updated to also report training parameters
Added:
incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java (with props)
Modified:
incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
Added: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java?rev=1124852&view=auto
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java (added)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java Thu May 19 14:37:31 2011
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.model;
+
+import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.math.BigInteger;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+
+import opennlp.model.Event;
+import opennlp.model.EventStream;
+
+/**
+ * An {@link EventStream} wrapper which passes every event of the wrapped
+ * stream through unchanged while feeding the UTF-8 bytes of the event's
+ * {@code toString()} form into an MD5 digest. Once the events have been
+ * read, {@link #calculateHashSum()} returns the digest as a number.
+ */
+public class HashSumEventStream implements EventStream {
+
+  private final EventStream eventStream;
+
+  private final MessageDigest digest;
+
+  /**
+   * Wraps the given stream and prepares an empty MD5 digest.
+   *
+   * @param eventStream the stream whose events are passed through and hashed
+   * @throws IllegalStateException if the runtime does not provide MD5
+   */
+  public HashSumEventStream(EventStream eventStream) {
+    this.eventStream = eventStream;
+
+    try {
+      digest = MessageDigest.getInstance("MD5");
+    } catch (NoSuchAlgorithmException e) {
+      // Should never happen: the MessageDigest documentation lists MD5 as an
+      // algorithm every Java platform implementation has to support.
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /**
+   * Delegates to the wrapped stream.
+   */
+  public boolean hasNext() throws IOException {
+    return eventStream.hasNext();
+  }
+
+  /**
+   * Fetches the next event from the wrapped stream, adds its string form to
+   * the digest and returns the event unchanged.
+   */
+  public Event next() throws IOException {
+
+    Event event = eventStream.next();
+
+    try {
+      digest.update(event.toString().getBytes("UTF-8"));
+    }
+    catch (UnsupportedEncodingException e) {
+      // Keep the original exception as cause so the failure stays traceable.
+      throw new IllegalStateException("UTF-8 encoding is not available!", e);
+    }
+
+    return event;
+  }
+
+  /**
+   * Calculates the hash sum over all events returned by {@link #next()} so
+   * far. The method is meant to be called after the stream is completely
+   * consumed, which means that hasNext() returns false.
+   * <p>
+   * Note: this precondition is not enforced. hasNext() declares IOException
+   * and this method does not, so the check cannot simply be made here.
+   * Calling the method early returns the hash of the events read up to that
+   * point. MessageDigest.digest() also resets the digest, so a second call
+   * does not return the same value again.
+   *
+   * @return the hash sum as a non-negative number
+   */
+  public BigInteger calculateHashSum() {
+    // Signum 1: interpret the digest bytes as an unsigned magnitude.
+    return new BigInteger(1, digest.digest());
+  }
+
+  /**
+   * Does nothing; the wrapped stream is not modified.
+   * NOTE(review): whether EventStream declares remove() is not visible here,
+   * so the method is kept as a no-op for compatibility - confirm.
+   */
+  public void remove() {
+  }
+}
Propchange: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java?rev=1124852&r1=1124851&r2=1124852&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java Thu May 19 14:37:31 2011
@@ -20,6 +20,7 @@
package opennlp.model;
import java.io.IOException;
+import java.util.HashMap;
import java.util.Map;
import opennlp.perceptron.SimplePerceptronSequenceTrainer;
@@ -34,33 +35,64 @@ public class TrainUtil {
public static final String CUTOFF_PARAM = "Cutoff";
- public static final String ITERATIONS_PARAM = "Iterations";
+ private static final int CUTOFF_DEFAULT = 5;
+ public static final String ITERATIONS_PARAM = "Iterations";
private static final int ITERATIONS_DEFAULT = 100;
- private static final int CUTOFF_DEFAULT = 5;
+ public static final String DATA_INDEXER_PARAM = "DataIndexer";
+ public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
+ public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
- private static int getIntParam(Map<String, String> trainParams, String key,
- int defaultValue) {
-
+
+  /**
+   * Resolves a string training parameter.
+   *
+   * @param trainParams the training parameters to read from
+   * @param key the parameter name
+   * @param defaultValue the value used when the key has no (non-null) entry
+   * @param reportMap receives the resolved value under the same key; may be
+   *     null, in which case nothing is reported
+   * @return the configured value, or defaultValue when none is configured
+   */
+  private static String getStringParam(Map<String, String> trainParams, String key,
+      String defaultValue, Map<String, String> reportMap) {
+
+    String configured = trainParams.get(key);
+    String effective = configured != null ? configured : defaultValue;
+
+    if (reportMap != null) {
+      reportMap.put(key, effective);
+    }
+
+    return effective;
+  }
+
+  /**
+   * Resolves an integer training parameter and reports the value which is
+   * actually used.
+   *
+   * @param trainParams the training parameters to read from
+   * @param key the parameter name
+   * @param defaultValue the value used when the key has no (non-null) entry
+   * @param reportMap receives the resolved value (as decimal string) under
+   *     the same key; may be null, in which case nothing is reported
+   * @return the parsed value, or defaultValue when none is configured
+   * @throws NumberFormatException if the configured value is not an integer
+   */
+  private static int getIntParam(Map<String, String> trainParams, String key,
+      int defaultValue, Map<String, String> reportMap) {
+
+    String valueString = trainParams.get(key);
+
+    int value = defaultValue;
+    if (valueString != null) {
+      value = Integer.parseInt(valueString);
+    }
+
+    // Report the effective value like getStringParam does; previously the
+    // reportMap argument was accepted but silently ignored.
+    if (reportMap != null) {
+      reportMap.put(key, Integer.toString(value));
+    }
+
+    return value;
+  }
+  /**
+   * Resolves a boolean training parameter and reports the value which is
+   * actually used.
+   *
+   * @param trainParams the training parameters to read from
+   * @param key the parameter name
+   * @param defaultValue the value used when the key has no (non-null) entry
+   * @param reportMap receives the resolved value ("true" or "false") under
+   *     the same key; may be null, in which case nothing is reported
+   * @return the parsed value, or defaultValue when none is configured
+   */
+  private static boolean getBooleanParam(Map<String, String> trainParams, String key,
+      boolean defaultValue, Map<String, String> reportMap) {
+
+    String valueString = trainParams.get(key);
+
+    boolean value = defaultValue;
+    if (valueString != null) {
+      value = Boolean.parseBoolean(valueString);
+    }
+
+    // Report the effective value like getStringParam does; previously the
+    // reportMap argument was accepted but silently ignored.
+    if (reportMap != null) {
+      reportMap.put(key, Boolean.toString(value));
+    }
+
+    return value;
+  }
+
+
public static boolean isValid(Map<String, String> trainParams) {
+
+ // TODO: Need to validate all parameters correctly ... error prone?!
String algorithmName = trainParams.get(ALGORITHM_PARAM);
-
- if (!(MAXENT_VALUE.equals(algorithmName) ||
+
+ if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName) ||
PERCEPTRON_VALUE.equals(algorithmName) ||
PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
return false;
}
-
+
try {
String cutoffString = trainParams.get(CUTOFF_PARAM);
if (cutoffString != null) Integer.parseInt(cutoffString);
@@ -72,40 +104,85 @@ public class TrainUtil {
return false;
}
+ String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);
+
+ if (dataIndexer != null) {
+ if (!("OnePass".equals(dataIndexer) || "TwoPass".equals(dataIndexer))) {
+ return false;
+ }
+ }
+
// TODO: Check data indexing ...
return true;
}
- public static AbstractModel train(EventStream events, Map<String, String> trainParams)
+
+
+ // TODO: Need a way to report results and settings back for inclusion in model ...
+
+ /**
+ * Trains a model on the given events with the algorithm selected in the
+ * training parameters.
+ * <p>
+ * The parameters are validated with isValid, sequence algorithms are rejected.
+ * Algorithm (default MAXENT_VALUE), iterations, cutoff and data indexer
+ * (default DATA_INDEXER_TWO_PASS_VALUE) are read from trainParams. The events
+ * are wrapped into a HashSumEventStream before they are indexed.
+ *
+ * @param events the training events
+ * @param trainParams the training parameters
+ * @param reportMap passed to the parameter getters and, if not null, also
+ * receives the hex encoded event hash under "Training-Eventhash"
+ * @return the trained model
+ * @throws IllegalArgumentException if trainParams are not valid or select
+ * sequence training
+ * @throws IllegalStateException if the algorithm or data indexer name is
+ * not one of the handled values
+ * @throws IOException declared by this method; presumably raised while the
+ * events are read
+ */
+ public static AbstractModel train(EventStream events, Map<String, String> trainParams, Map<String, String> reportMap)
throws IOException {
- // if PERCEPTRON or MAXENT
- String algorithmName = trainParams.get(ALGORITHM_PARAM);
+ if (!isValid(trainParams))
+ throw new IllegalArgumentException("trainParams are not valid!");
+
+ if(isSequenceTraining(trainParams))
+ throw new IllegalArgumentException("sequence training is not supported by this method!");
+
+ String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
+
+ int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
+
+ int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
+
+ boolean sortAndMerge;
+
+ if (MAXENT_VALUE.equals(algorithmName))
+ sortAndMerge = true;
+ // Perceptron training indexes without sorting, as it did before the
+ // indexer became configurable. This branch used to test MAXENT_VALUE
+ // again, which made it unreachable and let perceptron training fail below.
+ else if (PERCEPTRON_VALUE.equals(algorithmName))
+ sortAndMerge = false;
+ else
+ throw new IllegalStateException("Unexpected algorihtm name: " + algorithmName);
+
+ HashSumEventStream hses = new HashSumEventStream(events);
- // String DataIndexing -> OnePass|TwoPass
- // TODO: Make data indexing configurable ...
+ String dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
+ DATA_INDEXER_TWO_PASS_VALUE, reportMap);
+
+ DataIndexer indexer = null;
- int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT);
- int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT);
+ if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
+ indexer = new OnePassDataIndexer(hses, cutoff, sortAndMerge);
+ }
+ else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
+ indexer = new TwoPassDataIndexer(hses, cutoff, sortAndMerge);
+ }
+ else {
+ throw new IllegalStateException("Unexpected data indexer name: " + dataIndexerName);
+ }
AbstractModel model;
if (MAXENT_VALUE.equals(algorithmName)) {
- model = opennlp.maxent.GIS.trainModel(iterations,
- new TwoPassDataIndexer(events, cutoff));
+
+ // TODO: Pass in number of threads
+// int threads = getIntParam(trainParams, "Threads", 1, reportMap);
+
+ model = opennlp.maxent.GIS.trainModel(iterations, indexer);
}
else if (PERCEPTRON_VALUE.equals(algorithmName)) {
- boolean useAverage = true; // <- read from params
- boolean sort = false; // <- read from params
+ boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
model = new opennlp.perceptron.PerceptronTrainer().trainModel(
- iterations, new TwoPassDataIndexer(events,
- cutoff, sort), cutoff, useAverage);
+ iterations, indexer, cutoff, useAverage);
}
else {
throw new IllegalStateException("Algorithm not supported: " + algorithmName);
}
+ // The hash is taken after training, when the indexer has read the events
+ // through the hashing stream.
+ if (reportMap != null)
+ reportMap.put("Training-Eventhash", hses.calculateHashSum().toString(16));
+
return model;
}
@@ -114,22 +191,22 @@ public class TrainUtil {
* or not.
*/
public static boolean isSequenceTraining(Map<String, String> trainParams) {
-
- String algorithmName = trainParams.get(ALGORITHM_PARAM);
-
- return PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName);
+ // Constant-first equals(): when no algorithm is configured, get() returns
+ // null and the comparison yields false (assuming the constant is a
+ // non-null String, as its other uses suggest).
+ return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
}
- public static AbstractModel train(SequenceStream events, Map<String, String> trainParams)
- throws IOException {
+ public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
+ Map<String, String> reportMap) throws IOException {
+ if (!isValid(trainParams))
+ throw new IllegalArgumentException("trainParams are not valid!");
+
if (!isSequenceTraining(trainParams))
throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
- int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT);
- int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT);
+ int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
+ int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
- boolean useAverage = true; // <- TODO: read from params
+ boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
return new SimplePerceptronSequenceTrainer().trainModel(
iterations, events, cutoff,useAverage);