You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2016/12/20 09:04:42 UTC
opennlp git commit: Speed up GIS training by saving Executor in the GISTrainer
Repository: opennlp
Updated Branches:
refs/heads/trunk 06371807f -> 664d80330
Speed up GIS training by saving Executor in the GISTrainer
Thanks to Daniel Russ for providing a patch
See issue OPENNLP-759
Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo
Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/664d8033
Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/664d8033
Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/664d8033
Branch: refs/heads/trunk
Commit: 664d80330a3b4df05a524c3c202fe3f4de2806ba
Parents: 0637180
Author: Joern Kottmann <ko...@gmail.com>
Authored: Mon Dec 19 13:01:43 2016 +0100
Committer: Joern Kottmann <jo...@apache.org>
Committed: Tue Dec 20 09:40:28 2016 +0100
----------------------------------------------------------------------
.../opennlp/tools/ml/maxent/GISTrainer.java | 34 +++++++++++---------
1 file changed, 18 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/opennlp/blob/664d8033/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
index 0527979..9919bb0 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
@@ -23,7 +23,9 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
+import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@@ -403,6 +405,9 @@ class GISTrainer {
/* Estimate and return the model parameters. */
private void findParameters(int iterations, double correctionConstant) {
+ int threads=modelExpects.length;
+ ExecutorService executor = Executors.newFixedThreadPool(threads);
+ CompletionService<ModelExpactationComputeTask> completionService=new ExecutorCompletionService<GISTrainer.ModelExpactationComputeTask>(executor);
double prevLL = 0.0;
double currLL;
display("Performing " + iterations + " iterations.\n");
@@ -413,7 +418,7 @@ class GISTrainer {
display(" " + i + ": ");
else
display(i + ": ");
- currLL = nextIteration(correctionConstant);
+ currLL = nextIteration(correctionConstant,completionService);
if (i > 1) {
if (prevLL > currLL) {
System.err.println("Model Diverging: loglikelihood decreased");
@@ -431,6 +436,7 @@ class GISTrainer {
modelExpects = null;
numTimesEventsSeen = null;
contexts = null;
+ executor.shutdown();
}
//modeled on implementation in Zhang Le's maxent kit
@@ -544,34 +550,32 @@ class GISTrainer {
}
/* Compute one iteration of GIS and retutn log-likelihood.*/
- private double nextIteration(double correctionConstant) {
+ private double nextIteration(double correctionConstant, CompletionService<ModelExpactationComputeTask> completionService) {
// compute contribution of p(a|b_i) for each feature and the new
// correction parameter
double loglikelihood = 0.0;
int numEvents = 0;
int numCorrect = 0;
+ // Each thread gets equal number of tasks, if the number of tasks
+ // is not divisible by the number of threads, the first "leftOver"
+ // threads have one extra task.
int numberOfThreads = modelExpects.length;
-
- ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
-
int taskSize = numUniqueEvents / numberOfThreads;
-
int leftOver = numUniqueEvents % numberOfThreads;
- List<Future<?>> futures = new ArrayList<Future<?>>();
-
+ // submit all tasks to the completion service.
for (int i = 0; i < numberOfThreads; i++) {
- if (i != numberOfThreads - 1)
- futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize)));
+ if (i < leftOver)
+ completionService.submit(new ModelExpactationComputeTask(i, i*taskSize+i, taskSize+1));
else
- futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver)));
+ completionService.submit(new ModelExpactationComputeTask(i, i*taskSize+leftOver, taskSize));
}
- for (Future<?> future : futures) {
- ModelExpactationComputeTask finishedTask;
+ for (int i=0; i<numberOfThreads; i++) {
+ ModelExpactationComputeTask finishedTask = null;
try {
- finishedTask = (ModelExpactationComputeTask) future.get();
+ finishedTask = completionService.take().get();
} catch (InterruptedException e) {
// TODO: We got interrupted, but that is currently not really supported!
// For now we just print the exception and fail hard. We hopefully soon
@@ -591,8 +595,6 @@ class GISTrainer {
loglikelihood += finishedTask.getLoglikelihood();
}
- executor.shutdown();
-
display(".");
// merge the results of the two computations