You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2011/05/19 16:37:31 UTC

svn commit: r1124852 - in /incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model: HashSumEventStream.java TrainUtil.java

Author: joern
Date: Thu May 19 14:37:31 2011
New Revision: 1124852

URL: http://svn.apache.org/viewvc?rev=1124852&view=rev
Log:
OPENNLP-175 Updated to also report training parameters

Added:
    incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java   (with props)
Modified:
    incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java

Added: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java?rev=1124852&view=auto
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java (added)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java Thu May 19 14:37:31 2011
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.model;
+
+import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.math.BigInteger;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+
+import opennlp.model.Event;
+import opennlp.model.EventStream;
+
+/** Wraps an {@link EventStream} and feeds every event it returns into an MD5 digest. */
+public class HashSumEventStream implements EventStream {
+
+  private final EventStream eventStream;
+  private MessageDigest digest;
+
+  /**
+   * Creates a stream that hashes the events read through it from the given stream.
+   * @param eventStream the stream to delegate to
+   * @throws IllegalStateException if the runtime provides no MD5 implementation
+   */
+  public HashSumEventStream(EventStream eventStream) {
+    this.eventStream = eventStream;
+    try {
+      digest = MessageDigest.getInstance("MD5");
+    } catch (NoSuchAlgorithmException e) {
+      // Not expected in practice: MD5 should be available in every Java runtime.
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /** Delegates to the wrapped stream. */
+  public boolean hasNext() throws IOException {
+    return eventStream.hasNext();
+  }
+
+  /** Returns the next event of the wrapped stream after adding its string form to the digest. */
+  public Event next() throws IOException {
+    Event event = eventStream.next();
+
+    try {
+      digest.update(event.toString().getBytes("UTF-8"));
+    } catch (UnsupportedEncodingException e) {
+      throw new IllegalStateException("UTF-8 encoding is not available!", e);
+    }
+
+    return event;
+  }
+
+  /**
+   * Calculates the hash sum over all events returned by {@link #next()} so far.
+   * Call it only after the stream is consumed completely, i.e. once hasNext()
+   * returns false; this precondition is not checked. Calling it resets the digest.
+   *
+   * @return the hash sum as a non-negative integer
+   */
+  public BigInteger calculateHashSum() {
+    return new BigInteger(1, digest.digest());
+  }
+
+  /** Does nothing. */
+  public void remove() {
+  }
+}

Propchange: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/HashSumEventStream.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java?rev=1124852&r1=1124851&r2=1124852&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java Thu May 19 14:37:31 2011
@@ -20,6 +20,7 @@
 package opennlp.model;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 
 import opennlp.perceptron.SimplePerceptronSequenceTrainer;
@@ -34,33 +35,64 @@ public class TrainUtil {
   
   
   public static final String CUTOFF_PARAM = "Cutoff";
-  public static final String ITERATIONS_PARAM = "Iterations";
+  private static final int CUTOFF_DEFAULT = 5;
   
+  public static final String ITERATIONS_PARAM = "Iterations";
   private static final int ITERATIONS_DEFAULT = 100;
-  private static final int CUTOFF_DEFAULT = 5;
   
+  public static final String DATA_INDEXER_PARAM = "DataIndexer";
+  public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
+  public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
   
-  private static int getIntParam(Map<String, String> trainParams, String key,
-      int defaultValue) {
-    
+  
+  private static String getStringParam(Map<String, String> trainParams, String key,
+      String defaultValue, Map<String, String> reportMap) {
+
     String valueString = trainParams.get(key);
+
+    if (valueString == null)
+      valueString = defaultValue;
     
+    if (reportMap != null)
+      reportMap.put(key, valueString);
+    
+    return valueString;
+  }
+  
+  private static int getIntParam(Map<String, String> trainParams, String key,
+      int defaultValue, Map<String, String> reportMap) {
+
+    String valueString = trainParams.get(key);
+
     if (valueString != null)
       return Integer.parseInt(valueString);
     else
       return defaultValue;
   }
   
+  private static boolean getBooleanParam(Map<String, String> trainParams, String key,
+      boolean defaultValue, Map<String, String> reportMap) {
+
+    String valueString = trainParams.get(key);
+
+    if (valueString != null)
+      return Boolean.parseBoolean(valueString);
+    else
+      return defaultValue;
+  }
+  
   public static boolean isValid(Map<String, String> trainParams) {
+
+    // TODO: Need to validate all parameters correctly ... error prone?!
     
     String algorithmName = trainParams.get(ALGORITHM_PARAM);
-    
-    if (!(MAXENT_VALUE.equals(algorithmName) || 
+
+    if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName) || 
         PERCEPTRON_VALUE.equals(algorithmName) ||
         PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
       return false;
     }
-    
+
     try {
       String cutoffString = trainParams.get(CUTOFF_PARAM);
       if (cutoffString != null) Integer.parseInt(cutoffString);
@@ -72,40 +104,85 @@ public class TrainUtil {
       return false;
     }
     
+    String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);
+    
+    if (dataIndexer != null) {
+      if (!("OnePass".equals(dataIndexer) || "TwoPass".equals(dataIndexer))) {
+        return false;
+      }
+    }
+    
     // TODO: Check data indexing ... 
      
     return true;
   }
   
-  public static AbstractModel train(EventStream events, Map<String, String> trainParams) 
+  
+  
+  // TODO: Need a way to report results and settings back for inclusion in model ...
+  
+  public static AbstractModel train(EventStream events, Map<String, String> trainParams, Map<String, String> reportMap) 
       throws IOException {
     
-    // if PERCEPTRON or MAXENT
-    String algorithmName = trainParams.get(ALGORITHM_PARAM);
+    if (!isValid(trainParams))
+        throw new IllegalArgumentException("trainParams are not valid!");
+    
+    if(isSequenceTraining(trainParams))
+      throw new IllegalArgumentException("sequence training is not supported by this method!");
+    
+    String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
+    
+    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
+        
+    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
+
+    // Maxent indexes events with sort-and-merge enabled; the perceptron does not.
+    boolean sortAndMerge;
+    if (MAXENT_VALUE.equals(algorithmName))
+      sortAndMerge = true;
+    else if (PERCEPTRON_VALUE.equals(algorithmName))
+      sortAndMerge = false;
+    else
+      throw new IllegalStateException("Unexpected algorihtm name: " + algorithmName);
+
+    HashSumEventStream hses = new HashSumEventStream(events);
     
-    // String DataIndexing -> OnePass|TwoPass
-    // TODO: Make data indexing configurable ... 
+    String dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
+        DATA_INDEXER_TWO_PASS_VALUE, reportMap);
+
+    DataIndexer indexer = null;
     
-    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT);
-    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT);
+    if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
+      indexer = new OnePassDataIndexer(hses, cutoff, sortAndMerge);
+    }
+    else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
+      indexer = new TwoPassDataIndexer(hses, cutoff, sortAndMerge);
+    }
+    else {
+      throw new IllegalStateException("Unexpected data indexer name: " +  dataIndexerName);
+    }
     
     AbstractModel model;
     if (MAXENT_VALUE.equals(algorithmName)) {
-      model = opennlp.maxent.GIS.trainModel(iterations,
-          new TwoPassDataIndexer(events, cutoff));
+      
+      // TODO: Pass in number of threads
+//      int threads = getIntParam(trainParams, "Threads", 1, reportMap);
+      
+      model = opennlp.maxent.GIS.trainModel(iterations, indexer);
     }
     else if (PERCEPTRON_VALUE.equals(algorithmName)) {
-      boolean useAverage = true; // <- read from params 
-      boolean sort = false; // <- read from params
+      boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
       
       model = new opennlp.perceptron.PerceptronTrainer().trainModel(
-          iterations, new TwoPassDataIndexer(events,
-          cutoff, sort), cutoff, useAverage);
+          iterations, indexer, cutoff, useAverage);
     }
     else {
       throw new IllegalStateException("Algorithm not supported: " + algorithmName);
     }
     
+    if (reportMap != null)
+        reportMap.put("Training-Eventhash", hses.calculateHashSum().toString(16));
+    
     return model;
   }
   
@@ -114,22 +191,22 @@ public class TrainUtil {
    * or not.
    */
   public static boolean isSequenceTraining(Map<String, String> trainParams) {
-    
-    String algorithmName = trainParams.get(ALGORITHM_PARAM);
-    
-    return PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName);
+    return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
   }
   
-  public static AbstractModel train(SequenceStream events, Map<String, String> trainParams) 
-      throws IOException {
+  public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
+      Map<String, String> reportMap) throws IOException {
     
+    if (!isValid(trainParams))
+      throw new IllegalArgumentException("trainParams are not valid!");
+  
     if (!isSequenceTraining(trainParams))
       throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
     
-    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT);
-    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT);
+    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
+    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
     
-    boolean useAverage = true; // <- TODO: read from params
+    boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
     
     return new SimplePerceptronSequenceTrainer().trainModel(
         iterations, events, cutoff,useAverage);