Posted to commits@opennlp.apache.org by jo...@apache.org on 2011/05/23 15:59:30 UTC

svn commit: r1126493 - in /incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp: maxent/GIS.java maxent/GISTrainer.java model/TrainUtil.java

Author: joern
Date: Mon May 23 13:59:30 2011
New Revision: 1126493

URL: http://svn.apache.org/viewvc?rev=1126493&view=rev
Log:
OPENNLP-29 Added multi-threaded GIS training support

Modified:
    incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java
    incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java
    incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java

Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java?rev=1126493&r1=1126492&r2=1126493&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java Mon May 23 13:59:30 2011
@@ -24,6 +24,7 @@ import java.io.IOException;
 import opennlp.model.DataIndexer;
 import opennlp.model.EventStream;
 import opennlp.model.Prior;
+import opennlp.model.UniformPrior;
 
 /**
  * A Factory class which uses instances of GISTrainer to create and train
@@ -219,14 +220,40 @@ public class GIS {
   public static GISModel trainModel(int iterations, DataIndexer indexer,
       boolean printMessagesWhileTraining, boolean smoothing, Prior modelPrior,
       int cutoff) {
+    return trainModel(iterations, indexer, printMessagesWhileTraining,
+        smoothing, modelPrior, cutoff, 1);
+  }
+  
+  /**
+   * Train a model using the GIS algorithm.
+   * 
+   * @param iterations
+   *          The number of GIS iterations to perform.
+   * @param indexer
+   *          The object which will be used for event compilation.
+   * @param printMessagesWhileTraining
+   *          Determines whether training status messages are written to STDOUT.
+   * @param smoothing
+   *          Defines whether the created trainer will use smoothing while
+   *          training the model.
+   * @param modelPrior
+   *          The prior distribution for the model.
+   * @param cutoff
+   *          The number of times a predicate must occur to be used in a model.
+   * @param threads
+   *          The number of threads to use during training; must be at least one.
+   * @return The newly trained model, which can be used immediately or saved to
+   *         disk using an opennlp.maxent.io.GISModelWriter object.
+   */
+  public static GISModel trainModel(int iterations, DataIndexer indexer,
+      boolean printMessagesWhileTraining, boolean smoothing, Prior modelPrior,
+      int cutoff, int threads) {
     GISTrainer trainer = new GISTrainer(printMessagesWhileTraining);
     trainer.setSmoothing(smoothing);
     trainer.setSmoothingObservation(SMOOTHING_OBSERVATION);
-    if (modelPrior != null) {
-      return trainer.trainModel(iterations, indexer, modelPrior, cutoff);
-    } else {
-      return trainer.trainModel(iterations, indexer, cutoff);
+    if (modelPrior == null) {
+      modelPrior = new UniformPrior();
     }
+    
+    return trainer.trainModel(iterations, indexer, modelPrior, cutoff, threads);
   }
 }
 

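For context, the new GIS.trainModel overload above can be driven directly once the training events have been indexed. The sketch below is not part of this commit: the FileEventStream/OnePassDataIndexer set-up, the file name, the iteration count and the cutoff are illustrative assumptions; only the trainModel signature and the fall-back from a null prior to UniformPrior come from the diff.

    import java.io.IOException;

    import opennlp.maxent.GIS;
    import opennlp.maxent.GISModel;
    import opennlp.model.DataIndexer;
    import opennlp.model.FileEventStream;
    import opennlp.model.OnePassDataIndexer;

    public class MultiThreadedGisSketch {

      public static void main(String[] args) throws IOException {
        // Index the training events with a predicate cutoff of 5
        // ("train.events" is a placeholder file name).
        DataIndexer indexer =
            new OnePassDataIndexer(new FileEventStream("train.events"), 5);

        // 100 iterations, print progress, no smoothing, null prior
        // (replaced by a UniformPrior inside trainModel), cutoff 5,
        // and 4 worker threads for the expectation computation.
        GISModel model = GIS.trainModel(100, indexer, true, false, null, 5, 4);

        System.out.println("Outcomes: " + model.getNumOutcomes());
      }
    }

Passing 1 as the last argument reproduces the previous single-threaded behaviour, which is what the existing overloads now delegate to.
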
Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java?rev=1126493&r1=1126492&r2=1126493&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java Mon May 23 13:59:30 2011
@@ -20,6 +20,13 @@
 package opennlp.maxent;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 
 import opennlp.model.DataIndexer;
 import opennlp.model.EvalParameters;
@@ -135,10 +142,10 @@ class GISTrainer {
    */
   private MutableContext[] params;
 
-  /**
-   * Stores the expected values of the features based on the current models
+  /** 
+   * Stores the expected values of the features based on the current models 
    */
-  private MutableContext[] modelExpects;
+  private MutableContext[][] modelExpects;
 
   /**
    * This is the prior distribution that the model uses for training.
@@ -227,7 +234,7 @@ class GISTrainer {
    *         to disk using an opennlp.maxent.io.GISModelWriter object.
    */
   public GISModel trainModel(int iterations, DataIndexer di, int cutoff) {
-    return trainModel(iterations,di,new UniformPrior(),cutoff);
+    return trainModel(iterations,di,new UniformPrior(),cutoff,1);
   }
 
   /**
@@ -239,7 +246,13 @@ class GISTrainer {
    * @return The newly trained model, which can be used immediately or saved
    *         to disk using an opennlp.maxent.io.GISModelWriter object.
    */
-  public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff) {
+  public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) {
+    
+    if (threads <= 0)
+      throw new IllegalArgumentException("threads must be at least one!");
+    
+    modelExpects = new MutableContext[threads][];
+    
     /************** Incorporate all of the needed info ******************/
     display("Incorporating indexed data for training...  \n");
     contexts = di.getContexts();
@@ -311,7 +324,8 @@ class GISTrainer {
     // implementation, this is cancelled out when we compute the next
     // iteration of a parameter, making the extra divisions wasteful.
     params = new MutableContext[numPreds];
-    modelExpects = new MutableContext[numPreds];
+    for (int i = 0; i< modelExpects.length; i++)
+      modelExpects[i] = new MutableContext[numPreds];
     observedExpects = new MutableContext[numPreds];
     
     // The model does need the correction constant and the correction feature. The correction constant
@@ -350,12 +364,14 @@ class GISTrainer {
         }
       }
       params[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
-      modelExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
+      for (int i = 0; i< modelExpects.length; i++)
+        modelExpects[i][pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
       observedExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
       for (int aoi=0;aoi<numActiveOutcomes;aoi++) {
         int oi = outcomePattern[aoi];
         params[pi].setParameter(aoi, 0.0);
-        modelExpects[pi].setParameter(aoi, 0.0);
+        for (int i = 0; i< modelExpects.length; i++)
+          modelExpects[i][pi].setParameter(aoi, 0.0);
         if (predCount[pi][oi] > 0) {
             observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
         }
@@ -416,14 +432,13 @@ class GISTrainer {
     double param = params[predicate].getParameters()[oid];
     double x = 0.0;
     double x0 = 0.0;
-    double f;
     double tmp;
     double fp;
-    double modelValue = modelExpects[predicate].getParameters()[oid];
+    double modelValue = modelExpects[0][predicate].getParameters()[oid];
     double observedValue = observedExpects[predicate].getParameters()[oid];
     for (int i = 0; i < 50; i++) {
       tmp = modelValue * Math.exp(correctionConstant * x0);
-      f = tmp + (param + x0) / sigma - observedValue;
+      double f = tmp + (param + x0) / sigma - observedValue;
       fp = tmp * correctionConstant + 1 / sigma;
       if (fp == 0) {
         break;
@@ -438,61 +453,158 @@ class GISTrainer {
     return x0;
   }
   
+  private class ModelExpactationComputeTask implements Callable<ModelExpactationComputeTask> {
+
+    private final int startIndex;
+    private final int length; 
+    
+    private double loglikelihood = 0;
+    
+    private int numEvents = 0;
+    private int numCorrect = 0;
+    
+    private final int threadIndex;
+
+    // startIndex to compute, number of events to compute
+    ModelExpactationComputeTask(int threadIndex, int startIndex, int length) {
+      this.startIndex = startIndex;
+      this.length = length;
+      this.threadIndex = threadIndex;
+    }
+    
+    public ModelExpactationComputeTask call() {
+      
+      final double[] modelDistribution = new double[numOutcomes];
+      
+      
+      for (int ei = startIndex; ei < startIndex + length; ei++) {
+        
+        // TODO: check interruption status here, if interrupted set a poisoned flag and return
+        
+        if (values != null) {
+          prior.logPrior(modelDistribution, contexts[ei], values[ei]); 
+          GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
+        }
+        else {
+          prior.logPrior(modelDistribution,contexts[ei]);
+          GISModel.eval(contexts[ei], modelDistribution, evalParams);
+        }
+        for (int j = 0; j < contexts[ei].length; j++) {
+          int pi = contexts[ei][j];
+          if (predicateCounts[pi] >= cutoff) {
+            int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes();
+            for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
+              int oi = activeOutcomes[aoi];
+              
+              // TODO: Reads and writes to modelExpects must be thread-safe ...
+              // numTimesEventsSeen must also be thread-safe
+              if (values != null && values[ei] != null) {
+                modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]);
+              }
+              else {
+                modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * numTimesEventsSeen[ei]);
+              }
+            }
+          }
+        }
+        
+        loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
+        
+        numEvents += numTimesEventsSeen[ei];
+        if (printMessages) {
+          int max = 0;
+          for (int oi = 1; oi < numOutcomes; oi++) {
+            if (modelDistribution[oi] > modelDistribution[max]) {
+              max = oi;
+            }
+          }
+          if (max == outcomeList[ei]) {
+            numCorrect += numTimesEventsSeen[ei];
+          }
+        }
+        
+      }
+      
+      return this;
+    }
+    
+    synchronized int getNumEvents() {
+      return numEvents;
+    }
+    
+    synchronized int getNumCorrect() {
+      return numCorrect;
+    }
+    
+    synchronized double getLoglikelihood() {
+      return loglikelihood;
+    }
+  }
+  
   /* Compute one iteration of GIS and return the log-likelihood. */
   private double nextIteration(int correctionConstant) {
     // compute contribution of p(a|b_i) for each feature and the new
     // correction parameter
-    double[] modelDistribution = new double[numOutcomes];
     double loglikelihood = 0.0;
     int numEvents = 0;
     int numCorrect = 0;
-    for (int ei = 0; ei < numUniqueEvents; ei++) {
-      if (values != null) {
-        prior.logPrior(modelDistribution,contexts[ei],values[ei]);
-        GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
-      }
-      else {
-        prior.logPrior(modelDistribution,contexts[ei]);
-        GISModel.eval(contexts[ei], modelDistribution, evalParams);
+    
+    int numberOfThreads = modelExpects.length;
+    
+    ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
+    
+    int taskSize = numUniqueEvents / numberOfThreads;
+    
+    int leftOver = numUniqueEvents % numberOfThreads;
+    
+    List<Future<?>> futures = new ArrayList<Future<?>>();
+    
+    for (int i = 0; i < numberOfThreads; i++) {
+      if (i != numberOfThreads - 1)
+        futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize)));
+      else 
+        futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver)));
+    }
+    
+    for (Future<?> future : futures) {
+      ModelExpactationComputeTask finishedTask = null;
+      try {
+        finishedTask = (ModelExpactationComputeTask) future.get();
+      } catch (InterruptedException e) {
+        // In case we get interrupted, the exception should be rethrown
+        // and the executor service's shutdownNow should be called to stop any work
+        e.printStackTrace();
+      } catch (ExecutionException e) {
+        e.printStackTrace();
       }
       
-      for (int j = 0; j < contexts[ei].length; j++) {
-        int pi = contexts[ei][j];
-        if (predicateCounts[pi] >= cutoff) {
-          int[] activeOutcomes = modelExpects[pi].getOutcomes();
-          for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
-            int oi = activeOutcomes[aoi];
-            if (values != null && values[ei] != null) {
-              modelExpects[pi].updateParameter(aoi,modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]);
-            }
-            else {
-              modelExpects[pi].updateParameter(aoi,modelDistribution[oi] * numTimesEventsSeen[ei]);
-            }
-          }
-        }
-      }
+      // When they are done, retrieve the results ...
+      numEvents += finishedTask.getNumEvents();
+      numCorrect += finishedTask.getNumCorrect();
+      loglikelihood += finishedTask.getLoglikelihood();
+    }
+
+    executor.shutdown();
+    
+    display(".");
+
+    // merge the per-thread results into the first modelExpects slice
+    for (int pi = 0; pi < numPreds; pi++) {
+      int[] activeOutcomes = params[pi].getOutcomes();
       
-      loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
-      numEvents += numTimesEventsSeen[ei];
-      if (printMessages) {
-        int max = 0;
-        for (int oi = 1; oi < numOutcomes; oi++) {
-          if (modelDistribution[oi] > modelDistribution[max]) {
-            max = oi;
-          }
-        }
-        if (max == outcomeList[ei]) {
-          numCorrect += numTimesEventsSeen[ei];
+      for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
+        for (int i = 1; i < modelExpects.length; i++) {
+          modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]);
         }
       }
-
     }
+    
     display(".");
-
+    
     // compute the new parameter values
     for (int pi = 0; pi < numPreds; pi++) {
       double[] observed = observedExpects[pi].getParameters();
-      double[] model = modelExpects[pi].getParameters();
+      double[] model = modelExpects[0][pi].getParameters();
       int[] activeOutcomes = params[pi].getOutcomes();
       for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
         if (useGaussianSmoothing) {
@@ -505,7 +617,10 @@ class GISTrainer {
           //params[pi].updateParameter(aoi,(Math.log(observed[aoi]) - Math.log(model[aoi])));
           params[pi].updateParameter(aoi,((Math.log(observed[aoi]) - Math.log(model[aoi]))/correctionConstant));
         }
-        modelExpects[pi].setParameter(aoi,0.0); // re-initialize to 0.0's
+        
+        for (int i = 0; i< modelExpects.length; i++)
+          modelExpects[i][pi].setParameter(aoi,0.0); // re-initialize to 0.0's
+
       }
     }
 
@@ -518,5 +633,4 @@ class GISTrainer {
     if (printMessages)
       System.out.print(s);
   }
-
 }

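The nextIteration change above splits the unique events into one contiguous slice per thread, lets each ModelExpactationComputeTask accumulate into its own modelExpects[threadIndex] copy, and only after all futures have completed folds the per-thread totals into modelExpects[0] before the parameter update. The standalone sketch below shows the same partition-and-merge pattern on a plain array; none of these names belong to OpenNLP, it is only meant to illustrate why no locking is needed while the workers run.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    public class PartitionAndMerge {

      // Each task sums a disjoint slice of the input, so the workers never
      // touch shared state; results are combined only after Future.get().
      static final class SliceSum implements Callable<Double> {
        private final double[] data;
        private final int start;
        private final int length;

        SliceSum(double[] data, int start, int length) {
          this.data = data;
          this.start = start;
          this.length = length;
        }

        public Double call() {
          double sum = 0.0;
          for (int i = start; i < start + length; i++) {
            sum += data[i];
          }
          return sum;
        }
      }

      public static double parallelSum(double[] data, int threads)
          throws InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
          int taskSize = data.length / threads;
          int leftOver = data.length % threads;

          List<Future<Double>> futures = new ArrayList<Future<Double>>();
          for (int i = 0; i < threads; i++) {
            // The last task also takes the remainder that did not divide evenly.
            int length = (i == threads - 1) ? taskSize + leftOver : taskSize;
            futures.add(executor.submit(new SliceSum(data, i * taskSize, length)));
          }

          // Merge step: fold the per-thread partial sums into one total.
          double total = 0.0;
          for (Future<Double> future : futures) {
            total += future.get();
          }
          return total;
        } finally {
          executor.shutdown();
        }
      }
    }

In the commit itself the per-thread accumulators are the modelExpects[i] arrays rather than scalar sums, and getNumEvents/getNumCorrect/getLoglikelihood play the role of the returned partial results, but the submit/merge/shutdown structure is the same.
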
Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java?rev=1126493&r1=1126492&r2=1126493&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java Mon May 23 13:59:30 2011
@@ -165,10 +165,10 @@ public class TrainUtil {
     AbstractModel model;
     if (MAXENT_VALUE.equals(algorithmName)) {
       
-      // TODO: Pass in number of threads
-//      int threads = getIntParam(trainParams, "Threads", 1, reportMap);
+      int threads = getIntParam(trainParams, "Threads", 1, reportMap);
       
-      model = opennlp.maxent.GIS.trainModel(iterations, indexer);
+      model = opennlp.maxent.GIS.trainModel(iterations, indexer,
+          true, false, null, 0, threads);
     }
     else if (PERCEPTRON_VALUE.equals(algorithmName)) {
       boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
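
On the TrainUtil side, the only API-visible change is that a "Threads" entry in the training parameter map is now honoured, defaulting to 1 (single-threaded) when absent. The sketch below shows one way such a map could be assembled; the "Algorithm", "Iterations" and "Cutoff" keys and their values are assumptions about the surrounding configuration, not something introduced by this commit.

    import java.util.HashMap;
    import java.util.Map;

    public class ThreadsParamSketch {

      // Builds a parameter map of the shape TrainUtil reads. Only the
      // "Threads" key (and its default of 1) comes from the hunk above;
      // the remaining keys and values are illustrative assumptions.
      public static Map<String, String> maxentParams(int threads) {
        Map<String, String> trainParams = new HashMap<String, String>();
        trainParams.put("Algorithm", "MAXENT");
        trainParams.put("Iterations", "100");
        trainParams.put("Cutoff", "5");
        trainParams.put("Threads", Integer.toString(threads));
        return trainParams;
      }
    }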