You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") in the original page and is not preserved in this text version.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2011/05/23 15:59:30 UTC
svn commit: r1126493 - in /incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp: maxent/GIS.java maxent/GISTrainer.java model/TrainUtil.java
Author: joern
Date: Mon May 23 13:59:30 2011
New Revision: 1126493
URL: http://svn.apache.org/viewvc?rev=1126493&view=rev
Log:
OPENNLP-29 Added multi-threaded GIS training support
Modified:
incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java
incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java
incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java?rev=1126493&r1=1126492&r2=1126493&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GIS.java Mon May 23 13:59:30 2011
@@ -24,6 +24,7 @@ import java.io.IOException;
import opennlp.model.DataIndexer;
import opennlp.model.EventStream;
import opennlp.model.Prior;
+import opennlp.model.UniformPrior;
/**
* A Factory class which uses instances of GISTrainer to create and train
@@ -219,14 +220,40 @@ public class GIS {
public static GISModel trainModel(int iterations, DataIndexer indexer,
boolean printMessagesWhileTraining, boolean smoothing, Prior modelPrior,
int cutoff) {
+ return trainModel(iterations, indexer, printMessagesWhileTraining,
+ smoothing, modelPrior, cutoff, 1);
+ }
+
+ /**
+ * Train a model using the GIS algorithm.
+ *
+ * @param iterations
+ * The number of GIS iterations to perform.
+ * @param indexer
+ * The object which will be used for event compilation.
+ * @param printMessagesWhileTraining
+ * Determines whether training status messages are written to STDOUT.
+ * @param smoothing
+ * Defines whether the created trainer will use smoothing while
+ * training the model.
+ * @param modelPrior
+ * The prior distribution for the model.
+ * @param cutoff
+ * The number of times a predicate must occur to be used in a model.
+ * @return The newly trained model, which can be used immediately or saved to
+ * disk using an opennlp.maxent.io.GISModelWriter object.
+ */
+ public static GISModel trainModel(int iterations, DataIndexer indexer,
+ boolean printMessagesWhileTraining, boolean smoothing, Prior modelPrior,
+ int cutoff, int threads) {
GISTrainer trainer = new GISTrainer(printMessagesWhileTraining);
trainer.setSmoothing(smoothing);
trainer.setSmoothingObservation(SMOOTHING_OBSERVATION);
- if (modelPrior != null) {
- return trainer.trainModel(iterations, indexer, modelPrior, cutoff);
- } else {
- return trainer.trainModel(iterations, indexer, cutoff);
+ if (modelPrior == null) {
+ modelPrior = new UniformPrior();
}
+
+ return trainer.trainModel(iterations, indexer, modelPrior, cutoff, threads);
}
}
Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java?rev=1126493&r1=1126492&r2=1126493&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/GISTrainer.java Mon May 23 13:59:30 2011
@@ -20,6 +20,13 @@
package opennlp.maxent;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
@@ -135,10 +142,10 @@ class GISTrainer {
*/
private MutableContext[] params;
- /**
- * Stores the expected values of the features based on the current models
+ /**
+ * Stores the expected values of the features based on the current models
*/
- private MutableContext[] modelExpects;
+ private MutableContext[][] modelExpects;
/**
* This is the prior distribution that the model uses for training.
@@ -227,7 +234,7 @@ class GISTrainer {
* to disk using an opennlp.maxent.io.GISModelWriter object.
*/
public GISModel trainModel(int iterations, DataIndexer di, int cutoff) {
- return trainModel(iterations,di,new UniformPrior(),cutoff);
+ return trainModel(iterations,di,new UniformPrior(),cutoff,1);
}
/**
@@ -239,7 +246,13 @@ class GISTrainer {
* @return The newly trained model, which can be used immediately or saved
* to disk using an opennlp.maxent.io.GISModelWriter object.
*/
- public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff) {
+ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) {
+
+ if (threads <= 0)
+ throw new IllegalArgumentException("threads must be at leat one or greater!");
+
+ modelExpects = new MutableContext[threads][];
+
/************** Incorporate all of the needed info ******************/
display("Incorporating indexed data for training... \n");
contexts = di.getContexts();
@@ -311,7 +324,8 @@ class GISTrainer {
// implementation, this is cancelled out when we compute the next
// iteration of a parameter, making the extra divisions wasteful.
params = new MutableContext[numPreds];
- modelExpects = new MutableContext[numPreds];
+ for (int i = 0; i< modelExpects.length; i++)
+ modelExpects[i] = new MutableContext[numPreds];
observedExpects = new MutableContext[numPreds];
// The model does need the correction constant and the correction feature. The correction constant
@@ -350,12 +364,14 @@ class GISTrainer {
}
}
params[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
- modelExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
+ for (int i = 0; i< modelExpects.length; i++)
+ modelExpects[i][pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
observedExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
for (int aoi=0;aoi<numActiveOutcomes;aoi++) {
int oi = outcomePattern[aoi];
params[pi].setParameter(aoi, 0.0);
- modelExpects[pi].setParameter(aoi, 0.0);
+ for (int i = 0; i< modelExpects.length; i++)
+ modelExpects[i][pi].setParameter(aoi, 0.0);
if (predCount[pi][oi] > 0) {
observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
}
@@ -416,14 +432,13 @@ class GISTrainer {
double param = params[predicate].getParameters()[oid];
double x = 0.0;
double x0 = 0.0;
- double f;
double tmp;
double fp;
- double modelValue = modelExpects[predicate].getParameters()[oid];
+ double modelValue = modelExpects[0][predicate].getParameters()[oid];
double observedValue = observedExpects[predicate].getParameters()[oid];
for (int i = 0; i < 50; i++) {
tmp = modelValue * Math.exp(correctionConstant * x0);
- f = tmp + (param + x0) / sigma - observedValue;
+ double f = tmp + (param + x0) / sigma - observedValue;
fp = tmp * correctionConstant + 1 / sigma;
if (fp == 0) {
break;
@@ -438,61 +453,158 @@ class GISTrainer {
return x0;
}
+ private class ModelExpactationComputeTask implements Callable<ModelExpactationComputeTask> {
+
+ private final int startIndex;
+ private final int length;
+
+ private double loglikelihood = 0;
+
+ private int numEvents = 0;
+ private int numCorrect = 0;
+
+ final private int threadIndex;
+
+ // startIndex to compute, number of events to compute
+ ModelExpactationComputeTask(int threadIndex, int startIndex, int length) {
+ this.startIndex = startIndex;
+ this.length = length;
+ this.threadIndex = threadIndex;
+ }
+
+ public ModelExpactationComputeTask call() {
+
+ final double[] modelDistribution = new double[numOutcomes];
+
+
+ for (int ei = startIndex; ei < startIndex + length; ei++) {
+
+ // TODO: check interruption status here, if interrupted set a poisoned flag and return
+
+ if (values != null) {
+ prior.logPrior(modelDistribution, contexts[ei], values[ei]);
+ GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
+ }
+ else {
+ prior.logPrior(modelDistribution,contexts[ei]);
+ GISModel.eval(contexts[ei], modelDistribution, evalParams);
+ }
+ for (int j = 0; j < contexts[ei].length; j++) {
+ int pi = contexts[ei][j];
+ if (predicateCounts[pi] >= cutoff) {
+ int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes();
+ for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
+ int oi = activeOutcomes[aoi];
+
+ // TODO: Read and write to modelExpects must be thread safe ...
+ // numTimesEventsSeen must also be thread safe
+ if (values != null && values[ei] != null) {
+ modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]);
+ }
+ else {
+ modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * numTimesEventsSeen[ei]);
+ }
+ }
+ }
+ }
+
+ loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
+
+ numEvents += numTimesEventsSeen[ei];
+ if (printMessages) {
+ int max = 0;
+ for (int oi = 1; oi < numOutcomes; oi++) {
+ if (modelDistribution[oi] > modelDistribution[max]) {
+ max = oi;
+ }
+ }
+ if (max == outcomeList[ei]) {
+ numCorrect += numTimesEventsSeen[ei];
+ }
+ }
+
+ }
+
+ return this;
+ }
+
+ synchronized int getNumEvents() {
+ return numEvents;
+ }
+
+ synchronized int getNumCorrect() {
+ return numCorrect;
+ }
+
+ synchronized double getLoglikelihood() {
+ return loglikelihood;
+ }
+ }
+
/* Compute one iteration of GIS and retutn log-likelihood.*/
private double nextIteration(int correctionConstant) {
// compute contribution of p(a|b_i) for each feature and the new
// correction parameter
- double[] modelDistribution = new double[numOutcomes];
double loglikelihood = 0.0;
int numEvents = 0;
int numCorrect = 0;
- for (int ei = 0; ei < numUniqueEvents; ei++) {
- if (values != null) {
- prior.logPrior(modelDistribution,contexts[ei],values[ei]);
- GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
- }
- else {
- prior.logPrior(modelDistribution,contexts[ei]);
- GISModel.eval(contexts[ei], modelDistribution, evalParams);
+
+ int numberOfThreads = modelExpects.length;
+
+ ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
+
+ int taskSize = numUniqueEvents / numberOfThreads;
+
+ int leftOver = numUniqueEvents % numberOfThreads;
+
+ List<Future<?>> futures = new ArrayList<Future<?>>();
+
+ for (int i = 0; i < numberOfThreads; i++) {
+ if (i != numberOfThreads - 1)
+ futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize)));
+ else
+ futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver)));
+ }
+
+ for (Future<?> future : futures) {
+ ModelExpactationComputeTask finishedTask = null;
+ try {
+ finishedTask = (ModelExpactationComputeTask) future.get();
+ } catch (InterruptedException e) {
+ // In case we get interuppted, the exception should be rethrown
+ // and the executor services shutdownNow should be called, to stop any work
+ e.printStackTrace();
+ } catch (ExecutionException e) {
+ e.printStackTrace();
}
- for (int j = 0; j < contexts[ei].length; j++) {
- int pi = contexts[ei][j];
- if (predicateCounts[pi] >= cutoff) {
- int[] activeOutcomes = modelExpects[pi].getOutcomes();
- for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
- int oi = activeOutcomes[aoi];
- if (values != null && values[ei] != null) {
- modelExpects[pi].updateParameter(aoi,modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]);
- }
- else {
- modelExpects[pi].updateParameter(aoi,modelDistribution[oi] * numTimesEventsSeen[ei]);
- }
- }
- }
- }
+ // When they are done, retrieve the results ...
+ numEvents += finishedTask.getNumEvents();
+ numCorrect += finishedTask.getNumCorrect();
+ loglikelihood += finishedTask.getLoglikelihood();
+ }
+
+ executor.shutdown();
+
+ display(".");
+
+ // merge the results of the two computations
+ for (int pi = 0; pi < numPreds; pi++) {
+ int[] activeOutcomes = params[pi].getOutcomes();
- loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
- numEvents += numTimesEventsSeen[ei];
- if (printMessages) {
- int max = 0;
- for (int oi = 1; oi < numOutcomes; oi++) {
- if (modelDistribution[oi] > modelDistribution[max]) {
- max = oi;
- }
- }
- if (max == outcomeList[ei]) {
- numCorrect += numTimesEventsSeen[ei];
+ for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
+ for (int i = 1; i < modelExpects.length; i++) {
+ modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]);
}
}
-
}
+
display(".");
-
+
// compute the new parameter values
for (int pi = 0; pi < numPreds; pi++) {
double[] observed = observedExpects[pi].getParameters();
- double[] model = modelExpects[pi].getParameters();
+ double[] model = modelExpects[0][pi].getParameters();
int[] activeOutcomes = params[pi].getOutcomes();
for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
if (useGaussianSmoothing) {
@@ -505,7 +617,10 @@ class GISTrainer {
//params[pi].updateParameter(aoi,(Math.log(observed[aoi]) - Math.log(model[aoi])));
params[pi].updateParameter(aoi,((Math.log(observed[aoi]) - Math.log(model[aoi]))/correctionConstant));
}
- modelExpects[pi].setParameter(aoi,0.0); // re-initialize to 0.0's
+
+ for (int i = 0; i< modelExpects.length; i++)
+ modelExpects[i][pi].setParameter(aoi,0.0); // re-initialize to 0.0's
+
}
}
@@ -518,5 +633,4 @@ class GISTrainer {
if (printMessages)
System.out.print(s);
}
-
}
Modified: incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java?rev=1126493&r1=1126492&r2=1126493&view=diff
==============================================================================
--- incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java (original)
+++ incubator/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java Mon May 23 13:59:30 2011
@@ -165,10 +165,10 @@ public class TrainUtil {
AbstractModel model;
if (MAXENT_VALUE.equals(algorithmName)) {
- // TODO: Pass in number of threads
-// int threads = getIntParam(trainParams, "Threads", 1, reportMap);
+ int threads = getIntParam(trainParams, "Threads", 1, reportMap);
- model = opennlp.maxent.GIS.trainModel(iterations, indexer);
+ model = opennlp.maxent.GIS.trainModel(iterations, indexer,
+ true, false, null, 0, threads);
}
else if (PERCEPTRON_VALUE.equals(algorithmName)) {
boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);