Repository: opennlp Updated Branches: refs/heads/trunk 06371807f -> 664d80330
Speed up GIS training by saving Executor in the GISTrainer. Thanks to Daniel Russ for providing a patch. See issue OPENNLP-759. Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/664d8033 Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/664d8033 Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/664d8033 Branch: refs/heads/trunk Commit: 664d80330a3b4df05a524c3c202fe3f4de2806ba Parents: 0637180 Author: Joern Kottmann <[email protected]> Authored: Mon Dec 19 13:01:43 2016 +0100 Committer: Joern Kottmann <[email protected]> Committed: Tue Dec 20 09:40:28 2016 +0100 ---------------------------------------------------------------------- .../opennlp/tools/ml/maxent/GISTrainer.java | 34 +++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp/blob/664d8033/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java index 0527979..9919bb0 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java @@ -23,7 +23,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -403,6 +405,9 @@ class GISTrainer { /* Estimate and return the model parameters. 
*/ private void findParameters(int iterations, double correctionConstant) { + int threads=modelExpects.length; + ExecutorService executor = Executors.newFixedThreadPool(threads); + CompletionService<ModelExpactationComputeTask> completionService=new ExecutorCompletionService<GISTrainer.ModelExpactationComputeTask>(executor); double prevLL = 0.0; double currLL; display("Performing " + iterations + " iterations.\n"); @@ -413,7 +418,7 @@ class GISTrainer { display(" " + i + ": "); else display(i + ": "); - currLL = nextIteration(correctionConstant); + currLL = nextIteration(correctionConstant,completionService); if (i > 1) { if (prevLL > currLL) { System.err.println("Model Diverging: loglikelihood decreased"); @@ -431,6 +436,7 @@ class GISTrainer { modelExpects = null; numTimesEventsSeen = null; contexts = null; + executor.shutdown(); } //modeled on implementation in Zhang Le's maxent kit @@ -544,34 +550,32 @@ class GISTrainer { } /* Compute one iteration of GIS and retutn log-likelihood.*/ - private double nextIteration(double correctionConstant) { + private double nextIteration(double correctionConstant, CompletionService<ModelExpactationComputeTask> completionService) { // compute contribution of p(a|b_i) for each feature and the new // correction parameter double loglikelihood = 0.0; int numEvents = 0; int numCorrect = 0; + // Each thread gets equal number of tasks, if the number of tasks + // is not divisible by the number of threads, the first "leftOver" + // threads have one extra task. int numberOfThreads = modelExpects.length; - - ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads); - int taskSize = numUniqueEvents / numberOfThreads; - int leftOver = numUniqueEvents % numberOfThreads; - List<Future<?>> futures = new ArrayList<Future<?>>(); - + // submit all tasks to the completion service. 
for (int i = 0; i < numberOfThreads; i++) { - if (i != numberOfThreads - 1) - futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize))); + if (i < leftOver) + completionService.submit(new ModelExpactationComputeTask(i, i*taskSize+i, taskSize+1)); else - futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver))); + completionService.submit(new ModelExpactationComputeTask(i, i*taskSize+leftOver, taskSize)); } - for (Future<?> future : futures) { - ModelExpactationComputeTask finishedTask; + for (int i=0; i<numberOfThreads; i++) { + ModelExpactationComputeTask finishedTask = null; try { - finishedTask = (ModelExpactationComputeTask) future.get(); + finishedTask = completionService.take().get(); } catch (InterruptedException e) { // TODO: We got interrupted, but that is currently not really supported! // For now we just print the exception and fail hard. We hopefully soon @@ -591,8 +595,6 @@ class GISTrainer { loglikelihood += finishedTask.getLoglikelihood(); } - executor.shutdown(); - display("."); // merge the results of the two computations
