You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2016/12/20 09:04:42 UTC

opennlp git commit: Speed up GIS training by saving Executor in the GISTrainer

Repository: opennlp
Updated Branches:
  refs/heads/trunk 06371807f -> 664d80330


Speed up GIS training by saving Executor in the GISTrainer

Thanks to Daniel Russ for providing a patch.

See issue OPENNLP-759


Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo
Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/664d8033
Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/664d8033
Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/664d8033

Branch: refs/heads/trunk
Commit: 664d80330a3b4df05a524c3c202fe3f4de2806ba
Parents: 0637180
Author: Joern Kottmann <ko...@gmail.com>
Authored: Mon Dec 19 13:01:43 2016 +0100
Committer: Joern Kottmann <jo...@apache.org>
Committed: Tue Dec 20 09:40:28 2016 +0100

----------------------------------------------------------------------
 .../opennlp/tools/ml/maxent/GISTrainer.java     | 34 +++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/opennlp/blob/664d8033/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
index 0527979..9919bb0 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
@@ -23,7 +23,9 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Callable;
+import java.util.concurrent.CompletionService;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorCompletionService;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -403,6 +405,9 @@ class GISTrainer {
 
   /* Estimate and return the model parameters. */
   private void findParameters(int iterations, double correctionConstant) {
+	int threads=modelExpects.length;
+	ExecutorService executor = Executors.newFixedThreadPool(threads);
+	CompletionService<ModelExpactationComputeTask> completionService=new ExecutorCompletionService<GISTrainer.ModelExpactationComputeTask>(executor);
     double prevLL = 0.0;
     double currLL;
     display("Performing " + iterations + " iterations.\n");
@@ -413,7 +418,7 @@ class GISTrainer {
         display(" " + i + ":  ");
       else
         display(i + ":  ");
-      currLL = nextIteration(correctionConstant);
+      currLL = nextIteration(correctionConstant,completionService);
       if (i > 1) {
         if (prevLL > currLL) {
           System.err.println("Model Diverging: loglikelihood decreased");
@@ -431,6 +436,7 @@ class GISTrainer {
     modelExpects = null;
     numTimesEventsSeen = null;
     contexts = null;
+    executor.shutdown();
   }
 
   //modeled on implementation in  Zhang Le's maxent kit
@@ -544,34 +550,32 @@ class GISTrainer {
   }
 
   /* Compute one iteration of GIS and retutn log-likelihood.*/
-  private double nextIteration(double correctionConstant) {
+  private double nextIteration(double correctionConstant, CompletionService<ModelExpactationComputeTask> completionService) {
     // compute contribution of p(a|b_i) for each feature and the new
     // correction parameter
     double loglikelihood = 0.0;
     int numEvents = 0;
     int numCorrect = 0;
 
+    // Each thread gets equal number of tasks, if the number of tasks
+    // is not divisible by the number of threads, the first "leftOver" 
+    // threads have one extra task.
     int numberOfThreads = modelExpects.length;
-
-    ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
-
     int taskSize = numUniqueEvents / numberOfThreads;
-
     int leftOver = numUniqueEvents % numberOfThreads;
 
-    List<Future<?>> futures = new ArrayList<Future<?>>();
-
+    // submit all tasks to the completion service.
     for (int i = 0; i < numberOfThreads; i++) {
-      if (i != numberOfThreads - 1)
-        futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize)));
+      if (i < leftOver)
+        completionService.submit(new ModelExpactationComputeTask(i, i*taskSize+i, taskSize+1));
       else
-        futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver)));
+        completionService.submit(new ModelExpactationComputeTask(i, i*taskSize+leftOver, taskSize));
     }
 
-    for (Future<?> future : futures) {
-      ModelExpactationComputeTask finishedTask;
+    for (int i=0; i<numberOfThreads; i++) {
+      ModelExpactationComputeTask finishedTask = null;
       try {
-        finishedTask = (ModelExpactationComputeTask) future.get();
+        finishedTask = completionService.take().get();
       } catch (InterruptedException e) {
         // TODO: We got interrupted, but that is currently not really supported!
         // For now we just print the exception and fail hard. We hopefully soon
@@ -591,8 +595,6 @@ class GISTrainer {
       loglikelihood += finishedTask.getLoglikelihood();
     }
 
-    executor.shutdown();
-
     display(".");
 
     // merge the results of the two computations