This is an automated email from the ASF dual-hosted git repository.
mawiesne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 28e2de63 OPENNLP-124: Maxent/Perceptron training should report
progress back via an API (#758)
28e2de63 is described below
commit 28e2de631079401656aa6f4a94b175661ef91cad
Author: NishantShri4 <[email protected]>
AuthorDate: Fri Apr 25 18:59:14 2025 +0100
OPENNLP-124: Maxent/Perceptron training should report progress back via an
API (#758)
* OPENNLP-124 : Maxent/Perceptron training should report progress back via
an API
* OPENNLP-124 : Fixed Review Comments
* OPENNLP-124 : Updated javadoc for the new Trainer.init method
---
.../main/java/opennlp/tools/commons/Trainer.java | 11 ++
.../java/opennlp/tools/ml/AbstractTrainer.java | 20 ++++
.../main/java/opennlp/tools/ml/TrainerFactory.java | 69 +++++------
.../java/opennlp/tools/ml/maxent/GISTrainer.java | 52 ++++++++-
.../tools/ml/perceptron/PerceptronTrainer.java | 59 ++++++++--
.../monitoring/DefaultTrainingProgressMonitor.java | 93 +++++++++++++++
.../IterDeltaAccuracyUnderTolerance.java | 51 ++++++++
.../monitoring/LogLikelihoodThresholdBreached.java | 56 +++++++++
.../opennlp/tools/monitoring/StopCriteria.java} | 30 ++---
.../opennlp/tools/monitoring/TrainingMeasure.java} | 26 ++---
.../tools/monitoring/TrainingProgressMonitor.java | 63 ++++++++++
.../opennlp/tools/util/TrainingConfiguration.java} | 27 ++---
.../java/opennlp/tools/ml/MockEventTrainer.java | 7 ++
.../java/opennlp/tools/ml/MockSequenceTrainer.java | 8 +-
.../java/opennlp/tools/ml/TrainerFactoryTest.java | 20 ++++
.../DefaultTrainingProgressMonitorTest.java | 128 +++++++++++++++++++++
.../IterDeltaAccuracyUnderToleranceTest.java | 52 +++++++++
.../LogLikelihoodThresholdBreachedTest.java | 56 +++++++++
18 files changed, 738 insertions(+), 90 deletions(-)
diff --git a/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java
b/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java
index efd8ee76..ce8ed022 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/commons/Trainer.java
@@ -19,6 +19,7 @@ package opennlp.tools.commons;
import java.util.Map;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
/**
@@ -35,4 +36,14 @@ public interface Trainer {
*/
void init(TrainingParameters trainParams, Map<String, String> reportMap);
+ /**
+ * Conducts the initialization of a {@link Trainer} via
+ * {@link TrainingParameters}, {@link Map report map} and {@link
TrainingConfiguration}.
+ *
+ * @param trainParams The {@link TrainingParameters} to use.
+ * @param reportMap The {@link Map} instance used as report map.
+ * @param config The {@link TrainingConfiguration} to use. If null,
suitable defaults will be used.
+ */
+ void init(TrainingParameters trainParams, Map<String, String> reportMap,
TrainingConfiguration config);
+
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
index 54e315c8..2401e35f 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
@@ -22,12 +22,14 @@ import java.util.Map;
import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.maxent.GISTrainer;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
public abstract class AbstractTrainer implements Trainer {
protected TrainingParameters trainingParameters;
protected Map<String,String> reportMap;
+ protected TrainingConfiguration trainingConfiguration;
public AbstractTrainer() {
}
@@ -55,6 +57,16 @@ public abstract class AbstractTrainer implements Trainer {
this.reportMap = reportMap;
}
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void init(TrainingParameters trainParams, Map<String, String>
reportMap,
+ TrainingConfiguration config) {
+ init(trainParams, reportMap);
+ this.trainingConfiguration = config;
+ }
+
/**
* @return Retrieves the configured {@link
TrainingParameters#ALGORITHM_PARAM} value.
*/
@@ -108,4 +120,12 @@ public abstract class AbstractTrainer implements Trainer {
reportMap.put(key, value);
}
+ /**
+ * Retrieves the {@link TrainingConfiguration} associated with an {@link
AbstractTrainer}.
+ * @return {@link TrainingConfiguration}
+ */
+ public TrainingConfiguration getTrainingConfiguration() {
+ return trainingConfiguration;
+ }
+
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
index b47e3a75..6aaf73b6 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
@@ -26,6 +26,8 @@ import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.naivebayes.NaiveBayesTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
+import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.ext.ExtensionLoader;
import opennlp.tools.util.ext.ExtensionNotLoadedException;
@@ -62,12 +64,11 @@ public class TrainerFactory {
* {@link TrainingParameters#ALGORITHM_PARAM} value.
*
* @param trainParams - A mapping of {@link TrainingParameters training
parameters}.
- *
* @return The {@link TrainerType} or {@code null} if the type couldn't be
determined.
*/
public static TrainerType getTrainerType(TrainingParameters trainParams) {
- String algorithmValue =
trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM,null);
+ String algorithmValue =
trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM, null);
// Check if it is defaulting to the MAXENT trainer
if (algorithmValue == null) {
@@ -80,11 +81,9 @@ public class TrainerFactory {
if (EventTrainer.class.isAssignableFrom(trainerClass)) {
return TrainerType.EVENT_MODEL_TRAINER;
- }
- else if (EventModelSequenceTrainer.class.isAssignableFrom(trainerClass))
{
+ } else if
(EventModelSequenceTrainer.class.isAssignableFrom(trainerClass)) {
return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER;
- }
- else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
+ } else if (SequenceTrainer.class.isAssignableFrom(trainerClass)) {
return TrainerType.SEQUENCE_TRAINER;
}
}
@@ -94,24 +93,21 @@ public class TrainerFactory {
try {
ExtensionLoader.instantiateExtension(EventTrainer.class, algorithmValue);
return TrainerType.EVENT_MODEL_TRAINER;
- }
- catch (ExtensionNotLoadedException ignored) {
+ } catch (ExtensionNotLoadedException ignored) {
// this is ignored
}
try {
ExtensionLoader.instantiateExtension(EventModelSequenceTrainer.class,
algorithmValue);
return TrainerType.EVENT_MODEL_SEQUENCE_TRAINER;
- }
- catch (ExtensionNotLoadedException ignored) {
+ } catch (ExtensionNotLoadedException ignored) {
// this is ignored
}
try {
ExtensionLoader.instantiateExtension(SequenceTrainer.class,
algorithmValue);
return TrainerType.SEQUENCE_TRAINER;
- }
- catch (ExtensionNotLoadedException ignored) {
+ } catch (ExtensionNotLoadedException ignored) {
// this is ignored
}
@@ -124,15 +120,14 @@ public class TrainerFactory {
* @param trainParams The {@link TrainingParameters} to check for the
trainer type.
* Note: The entry {@link
TrainingParameters#ALGORITHM_PARAM} is used
* to determine the type.
- * @param reportMap A {@link Map} that shall be used during initialization of
- * the {@link SequenceTrainer}.
- *
+ * @param reportMap A {@link Map} that shall be used during initialization
of
+ * the {@link SequenceTrainer}.
* @return A valid {@link SequenceTrainer} for the configured {@code
trainParams}.
* @throws IllegalArgumentException Thrown if the trainer type could not be
determined.
*/
public static SequenceTrainer getSequenceModelTrainer(
- TrainingParameters trainParams, Map<String, String> reportMap) {
- String trainerType =
trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM,null);
+ TrainingParameters trainParams, Map<String, String> reportMap) {
+ String trainerType =
trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM, null);
if (trainerType != null) {
final SequenceTrainer trainer;
@@ -143,8 +138,7 @@ public class TrainerFactory {
}
trainer.init(trainParams, reportMap);
return trainer;
- }
- else {
+ } else {
throw new IllegalArgumentException("Trainer type couldn't be
determined!");
}
}
@@ -155,15 +149,14 @@ public class TrainerFactory {
* @param trainParams The {@link TrainingParameters} to check for the
trainer type.
* Note: The entry {@link
TrainingParameters#ALGORITHM_PARAM} is used
* to determine the type.
- * @param reportMap A {@link Map} that shall be used during initialization of
- * the {@link EventModelSequenceTrainer}.
- *
+ * @param reportMap A {@link Map} that shall be used during initialization
of
+ * the {@link EventModelSequenceTrainer}.
* @return A valid {@link EventModelSequenceTrainer} for the configured
{@code trainParams}.
* @throws IllegalArgumentException Thrown if the trainer type could not be
determined.
*/
public static <T> EventModelSequenceTrainer<T> getEventModelSequenceTrainer(
- TrainingParameters trainParams, Map<String, String> reportMap) {
- String trainerType =
trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM,null);
+ TrainingParameters trainParams, Map<String, String> reportMap) {
+ String trainerType =
trainParams.getStringParameter(TrainingParameters.ALGORITHM_PARAM, null);
if (trainerType != null) {
final EventModelSequenceTrainer<T> trainer;
@@ -174,12 +167,23 @@ public class TrainerFactory {
}
trainer.init(trainParams, reportMap);
return trainer;
- }
- else {
+ } else {
throw new IllegalArgumentException("Trainer type couldn't be
determined!");
}
}
+ /**
+ * Works just like {@link TrainerFactory#getEventTrainer(TrainingParameters,
Map, TrainingConfiguration)}
+ * except that {@link TrainingConfiguration} is initialized with default
values.
+ */
+ public static EventTrainer getEventTrainer(
+ TrainingParameters trainParams, Map<String, String> reportMap) {
+
+ TrainingConfiguration trainingConfiguration
+ = new TrainingConfiguration(new DefaultTrainingProgressMonitor(),
null);
+ return getEventTrainer(trainParams, reportMap, trainingConfiguration);
+ }
+
/**
* Retrieves an {@link EventTrainer} that fits the given parameters.
*
@@ -187,13 +191,13 @@ public class TrainerFactory {
* Note: The entry {@link
TrainingParameters#ALGORITHM_PARAM} is used
* to determine the type. If the type is not defined, the
* {@link GISTrainer#MAXENT_VALUE} will be used.
- * @param reportMap A {@link Map} that shall be used during initialization of
- * the {@link EventTrainer}.
- *
+ * @param reportMap A {@link Map} that shall be used during initialization
of
+ * the {@link EventTrainer}.
+ * @param config The {@link TrainingConfiguration} to be used.
* @return A valid {@link EventTrainer} for the configured {@code
trainParams}.
*/
public static EventTrainer getEventTrainer(
- TrainingParameters trainParams, Map<String, String> reportMap) {
+ TrainingParameters trainParams, Map<String, String> reportMap,
TrainingConfiguration config) {
// if the trainerType is not defined -- use the GISTrainer.
String trainerType = trainParams.getStringParameter(
@@ -205,7 +209,7 @@ public class TrainerFactory {
} else {
trainer = ExtensionLoader.instantiateExtension(EventTrainer.class,
trainerType);
}
- trainer.init(trainParams, reportMap);
+ trainer.init(trainParams, reportMap, config);
return trainer;
}
@@ -232,8 +236,7 @@ public class TrainerFactory {
TrainingParameters.CUTOFF_DEFAULT_VALUE);
trainParams.getIntParameter(TrainingParameters.ITERATIONS_PARAM,
TrainingParameters.ITERATIONS_DEFAULT_VALUE);
- }
- catch (NumberFormatException e) {
+ } catch (NumberFormatException e) {
return false;
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
index d2eabeb9..ae71d942 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
@@ -30,6 +30,7 @@ import java.util.concurrent.Executors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.DataIndexer;
@@ -40,7 +41,13 @@ import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Prior;
import opennlp.tools.ml.model.UniformPrior;
+import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
+import opennlp.tools.monitoring.LogLikelihoodThresholdBreached;
+import opennlp.tools.monitoring.StopCriteria;
+import opennlp.tools.monitoring.TrainingMeasure;
+import opennlp.tools.monitoring.TrainingProgressMonitor;
import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
/**
@@ -497,6 +504,11 @@ public class GISTrainer extends AbstractEventTrainer {
new ExecutorCompletionService<>(executor);
double prevLL = 0.0;
double currLL;
+
+ //Get the Training Progress Monitor and the StopCriteria.
+ TrainingProgressMonitor progressMonitor =
getTrainingProgressMonitor(trainingConfiguration);
+ StopCriteria stopCriteria = getStopCriteria(trainingConfiguration);
+
logger.info("Performing {} iterations.", iterations);
for (int i = 1; i <= iterations; i++) {
currLL = nextIteration(correctionConstant, completionService, i);
@@ -505,13 +517,20 @@ public class GISTrainer extends AbstractEventTrainer {
logger.warn("Model Diverging: loglikelihood decreased");
break;
}
- if (currLL - prevLL < llThreshold) {
+ if (stopCriteria.test(currLL - prevLL)) {
+ progressMonitor.finishedTraining(iterations, stopCriteria);
break;
}
}
prevLL = currLL;
}
+ //At this point, all iterations have finished successfully.
+ if (!progressMonitor.isTrainingFinished()) {
+ progressMonitor.finishedTraining(iterations, null);
+ }
+ progressMonitor.display(true);
+
// kill a bunch of these big objects now that we don't need them
observedExpects = null;
modelExpects = null;
@@ -628,8 +647,8 @@ public class GISTrainer extends AbstractEventTrainer {
}
}
- logger.info("{} - loglikelihood={}\t{}",
- iteration, loglikelihood, ((double) numCorrect / numEvents));
+ getTrainingProgressMonitor(trainingConfiguration).
+ finishedIteration(iteration, numCorrect, numEvents,
TrainingMeasure.LOG_LIKELIHOOD, loglikelihood);
return loglikelihood;
}
@@ -709,4 +728,31 @@ public class GISTrainer extends AbstractEventTrainer {
return loglikelihood;
}
}
+
+ /**
+ * Get the {@link StopCriteria} associated with this {@link Trainer}.
+ *
+ * @param trainingConfig {@link TrainingConfiguration}
+ * @return {@link StopCriteria}. If {@link TrainingConfiguration} is {@code
null} or
+ * {@link TrainingConfiguration#stopCriteria()} is {@code null},
+ * then return the default {@link StopCriteria}.
+ */
+ private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) {
+ return trainingConfig != null && trainingConfig.stopCriteria() != null
+ ? trainingConfig.stopCriteria() : new
LogLikelihoodThresholdBreached(trainingParameters);
+ }
+
+ /**
+ * Get the {@link TrainingProgressMonitor} associated with this {@link
Trainer}.
+ *
+ * @param trainingConfig {@link TrainingConfiguration}.
+ * @return {@link TrainingProgressMonitor}. If {@link TrainingConfiguration}
is {@code null} or
+ * {@link TrainingConfiguration#progMon()} is {@code null},
+ * then return the default {@link TrainingProgressMonitor}.
+ */
+ private TrainingProgressMonitor
getTrainingProgressMonitor(TrainingConfiguration trainingConfig) {
+ return trainingConfig != null && trainingConfig.progMon() != null ?
+ trainingConfig.progMon() : new DefaultTrainingProgressMonitor();
+ }
+
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
index d4d51b6c..9958e035 100644
---
a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
+++
b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java
@@ -22,12 +22,19 @@ import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import opennlp.tools.commons.Trainer;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;
+import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
+import opennlp.tools.monitoring.IterDeltaAccuracyUnderTolerance;
+import opennlp.tools.monitoring.StopCriteria;
+import opennlp.tools.monitoring.TrainingMeasure;
+import opennlp.tools.monitoring.TrainingProgressMonitor;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
/**
@@ -293,6 +300,10 @@ public class PerceptronTrainer extends
AbstractEventTrainer {
}
}
+ //Get the Training Progress Monitor and the StopCriteria.
+ TrainingProgressMonitor progressMonitor =
getTrainingProgressMonitor(trainingConfiguration);
+ StopCriteria stopCriteria = getStopCriteria(trainingConfiguration);
+
// Keep track of the previous three accuracies. The difference of
// the mean of these and the current training set accuracy is used
// with tolerance to decide whether to stop.
@@ -349,10 +360,12 @@ public class PerceptronTrainer extends
AbstractEventTrainer {
}
}
- // Calculate the training accuracy and display.
+ // Calculate the training accuracy.
double trainingAccuracy = (double) numCorrect / numEvents;
- if (i < 10 || (i % 10) == 0)
- logger.info("{}: ({}/{}) {}", i, numCorrect, numEvents,
trainingAccuracy);
+ if (i < 10 || (i % 10) == 0) {
+ progressMonitor.finishedIteration(i, numCorrect, numEvents,
+ TrainingMeasure.ACCURACY, trainingAccuracy);
+ }
// TODO: Make averaging configurable !!!
@@ -370,10 +383,10 @@ public class PerceptronTrainer extends
AbstractEventTrainer {
// If the tolerance is greater than the difference between the
// current training accuracy and all of the previous three
// training accuracies, stop training.
- if (StrictMath.abs(prevAccuracy1 - trainingAccuracy) < tolerance
- && StrictMath.abs(prevAccuracy2 - trainingAccuracy) < tolerance
- && StrictMath.abs(prevAccuracy3 - trainingAccuracy) < tolerance) {
- logger.warn("Stopping: change in training set accuracy less than {}",
tolerance);
+ if (stopCriteria.test(prevAccuracy1 - trainingAccuracy)
+ && stopCriteria.test(prevAccuracy2 - trainingAccuracy)
+ && stopCriteria.test(prevAccuracy3 - trainingAccuracy)) {
+ progressMonitor.finishedTraining(iterations, stopCriteria);
break;
}
@@ -383,6 +396,12 @@ public class PerceptronTrainer extends
AbstractEventTrainer {
prevAccuracy3 = trainingAccuracy;
}
+ //At this point, all iterations have finished successfully.
+ if (!progressMonitor.isTrainingFinished()) {
+ progressMonitor.finishedTraining(iterations, null);
+ }
+ progressMonitor.display(true);
+
// Output the final training stats.
trainingStats(evalParams);
@@ -432,4 +451,30 @@ public class PerceptronTrainer extends
AbstractEventTrainer {
return root * root == n;
}
+ /**
+ * Get the {@link StopCriteria} associated with this {@link Trainer}.
+ *
+ * @param trainingConfig {@link TrainingConfiguration}
+ * @return {@link StopCriteria}. If {@link TrainingConfiguration} is {@code
null} or
+ * {@link TrainingConfiguration#stopCriteria()} is {@code null},
+ * then return the default {@link StopCriteria}.
+ */
+ private StopCriteria getStopCriteria(TrainingConfiguration trainingConfig) {
+ return trainingConfig != null && trainingConfig.stopCriteria() != null
+ ? trainingConfig.stopCriteria() : new
IterDeltaAccuracyUnderTolerance(trainingParameters);
+ }
+
+ /**
+ * Get the {@link TrainingProgressMonitor} associated with this {@link
Trainer}.
+ *
+ * @param trainingConfig {@link TrainingConfiguration}.
+ * @return {@link TrainingProgressMonitor}. If {@link TrainingConfiguration}
is {@code null} or
+ * {@link TrainingConfiguration#progMon()} is {@code null},
+ * then return the default {@link TrainingProgressMonitor}.
+ */
+ private TrainingProgressMonitor
getTrainingProgressMonitor(TrainingConfiguration trainingConfig) {
+ return trainingConfig != null && trainingConfig.progMon() != null ?
trainingConfig.progMon() :
+ new DefaultTrainingProgressMonitor();
+ }
+
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitor.java
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitor.java
new file mode 100644
index 00000000..2d108251
--- /dev/null
+++
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitor.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static opennlp.tools.monitoring.StopCriteria.FINISHED;
+
+/**
+ * The default implementation of {@link TrainingProgressMonitor}.
+ * This publishes model training progress to the chosen logging destination.
+ */
+public class DefaultTrainingProgressMonitor implements TrainingProgressMonitor
{
+
+ private static final Logger logger =
LoggerFactory.getLogger(DefaultTrainingProgressMonitor.class);
+
+ /**
+   * Keeps track of whether training was already finished.
+ */
+ private volatile boolean isTrainingFinished;
+
+ /**
+ * An underlying list to capture training progress events.
+ */
+ private final List<String> progress;
+
+ public DefaultTrainingProgressMonitor() {
+ this.progress = new LinkedList<>();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public synchronized void finishedIteration(int iteration, int
numberCorrectEvents, int totalEvents,
+ TrainingMeasure measure, double
measureValue) {
+ progress.add(String.format("%s: (%s/%s) %s : %s", iteration,
numberCorrectEvents, totalEvents,
+ measure.getMeasureName(), measureValue));
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public synchronized void finishedTraining(int iterations, StopCriteria
stopCriteria) {
+ if (!Objects.isNull(stopCriteria)) {
+ progress.add(stopCriteria.getMessageIfSatisfied());
+ } else {
+ progress.add(String.format(FINISHED, iterations));
+ }
+ isTrainingFinished = true;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public synchronized void display(boolean clear) {
+ progress.stream().forEach(logger::info);
+ if (clear) {
+ progress.clear();
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public boolean isTrainingFinished() {
+ return isTrainingFinished;
+ }
+}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderTolerance.java
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderTolerance.java
new file mode 100644
index 00000000..958a27b3
--- /dev/null
+++
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderTolerance.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+import opennlp.tools.ml.perceptron.PerceptronTrainer;
+import opennlp.tools.util.TrainingParameters;
+
+/**
+ * A {@link StopCriteria} implementation to identify whether the absolute
+ * difference between the training accuracy of current and previous iteration
is under the defined tolerance.
+ */
+public class IterDeltaAccuracyUnderTolerance implements StopCriteria<Double> {
+
+ public static final String STOP = "Stopping: change in training set accuracy
less than {%s}";
+ private final TrainingParameters trainingParameters;
+
+ public IterDeltaAccuracyUnderTolerance(TrainingParameters
trainingParameters) {
+ this.trainingParameters = trainingParameters;
+ }
+
+ @Override
+ public String getMessageIfSatisfied() {
+ return String.format(STOP, getTolerance());
+ }
+
+ @Override
+ public boolean test(Double deltaAccuracy) {
+ return StrictMath.abs(deltaAccuracy) < getTolerance();
+ }
+
+ private double getTolerance() {
+ return trainingParameters != null ?
trainingParameters.getDoubleParameter("Tolerance",
+ PerceptronTrainer.TOLERANCE_DEFAULT) :
PerceptronTrainer.TOLERANCE_DEFAULT;
+ }
+
+}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreached.java
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreached.java
new file mode 100644
index 00000000..f6f34896
--- /dev/null
+++
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreached.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+import opennlp.tools.util.TrainingParameters;
+
+import static
opennlp.tools.ml.maxent.GISTrainer.LOG_LIKELIHOOD_THRESHOLD_DEFAULT;
+import static
opennlp.tools.ml.maxent.GISTrainer.LOG_LIKELIHOOD_THRESHOLD_PARAM;
+
+/**
+ * A {@link StopCriteria} implementation to identify whether the
+ * difference between the log likelihood of current and previous iteration is
under the defined threshold.
+ */
+public class LogLikelihoodThresholdBreached implements StopCriteria<Double> {
+
+ public static String STOP = "Stopping: Difference between log likelihood of
current" +
+ " and previous iteration is less than threshold %s .";
+
+ private final TrainingParameters trainingParameters;
+
+ public LogLikelihoodThresholdBreached(TrainingParameters trainingParameters)
{
+ this.trainingParameters = trainingParameters;
+ }
+
+ @Override
+ public String getMessageIfSatisfied() {
+ return String.format(STOP, getThreshold());
+
+ }
+
+ @Override
+ public boolean test(Double currVsPrevLLDiff) {
+ return currVsPrevLLDiff < getThreshold();
+ }
+
+ private double getThreshold() {
+ return trainingParameters != null ?
trainingParameters.getDoubleParameter(LOG_LIKELIHOOD_THRESHOLD_PARAM,
+ LOG_LIKELIHOOD_THRESHOLD_DEFAULT) : LOG_LIKELIHOOD_THRESHOLD_DEFAULT;
+ }
+
+}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/StopCriteria.java
similarity index 62%
copy from opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
copy to opennlp-tools/src/main/java/opennlp/tools/monitoring/StopCriteria.java
index 0d26ffbc..576aa7e5 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/StopCriteria.java
@@ -15,24 +15,26 @@
* limitations under the License.
*/
-package opennlp.tools.ml;
+package opennlp.tools.monitoring;
-import java.util.Map;
+import java.util.function.Predicate;
import opennlp.tools.ml.model.AbstractModel;
-import opennlp.tools.ml.model.Event;
-import opennlp.tools.ml.model.SequenceStream;
-import opennlp.tools.util.TrainingParameters;
-public class MockSequenceTrainer implements EventModelSequenceTrainer<Event> {
- @Override
- public AbstractModel train(SequenceStream<Event> events) {
- return null;
- }
+/**
+ * Stop criteria for model training. If the predicate is met, then the
training is aborted.
+ *
+ * @see Predicate
+ * @see AbstractModel
+ */
+public interface StopCriteria<T extends Number> extends Predicate<T> {
+
+ String FINISHED = "Training Finished after completing %s Iterations
successfully.";
+
+ /**
+ * @return A detailed message captured upon hitting the {@link StopCriteria}
during model training.
+ */
+ String getMessageIfSatisfied();
- @Override
- public void init(TrainingParameters trainParams, Map<String, String>
reportMap) {
- }
-
}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingMeasure.java
similarity index 62%
copy from opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
copy to
opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingMeasure.java
index 0d26ffbc..9b61ed2e 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingMeasure.java
@@ -15,24 +15,22 @@
* limitations under the License.
*/
-package opennlp.tools.ml;
+package opennlp.tools.monitoring;
-import java.util.Map;
-
-import opennlp.tools.ml.model.AbstractModel;
-import opennlp.tools.ml.model.Event;
-import opennlp.tools.ml.model.SequenceStream;
-import opennlp.tools.util.TrainingParameters;
+/**
+ * Enumeration of Training measures.
+ */
+public enum TrainingMeasure {
+ ACCURACY("Training Accuracy"),
+ LOG_LIKELIHOOD("Log Likelihood");
-public class MockSequenceTrainer implements EventModelSequenceTrainer<Event> {
+ private String measureName;
- @Override
- public AbstractModel train(SequenceStream<Event> events) {
- return null;
+ TrainingMeasure(String measureName) {
+ this.measureName = measureName;
}
- @Override
- public void init(TrainingParameters trainParams, Map<String, String>
reportMap) {
+ public String getMeasureName() {
+ return measureName;
}
-
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingProgressMonitor.java
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingProgressMonitor.java
new file mode 100644
index 00000000..be35b78a
--- /dev/null
+++
b/opennlp-tools/src/main/java/opennlp/tools/monitoring/TrainingProgressMonitor.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+import opennlp.tools.ml.model.AbstractModel;
+
+/**
+ * An interface to capture training progress of an {@link AbstractModel}.
+ */
+
+public interface TrainingProgressMonitor {
+
+ /**
+ * Captures the progress of a completed training iteration.
+ *
+ * @param iteration The completed iteration number.
+ * @param numberCorrectEvents Number of correctly predicted events in this
iteration.
+ * @param totalEvents Total count of events processed in this
iteration.
+ * @param measure The {@link TrainingMeasure} being reported.
+ * @param measureValue measure value corresponding to the applicable
{@link TrainingMeasure}.
+ */
+ void finishedIteration(int iteration, int numberCorrectEvents, int
totalEvents,
+ TrainingMeasure measure, double measureValue);
+
+ /**
+ * Captures the completion of model training.
+ *
+ * @param iterations Total number of iterations configured for the training.
+ * @param stopCriteria The {@link StopCriteria} that ended the training, or
+ * {@code null} if training ran through all configured iterations.
+ */
+ void finishedTraining(int iterations, StopCriteria stopCriteria);
+
+ /**
+ * Checks whether the training has finished.
+ *
+ * @return {@code true} if the training has finished, {@code false} if the
training is not yet completed.
+ */
+ boolean isTrainingFinished();
+
+ /**
+ * Displays the training progress and optionally clears the recorded
progress (to save memory).
+ * Callers of this method can invoke it periodically
+ * during training, to avoid holding too much progress related data in
memory.
+ *
+ * @param clear Set to true to clear the recorded progress.
+ */
+ void display(boolean clear);
+}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
b/opennlp-tools/src/main/java/opennlp/tools/util/TrainingConfiguration.java
similarity index 63%
copy from opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
copy to
opennlp-tools/src/main/java/opennlp/tools/util/TrainingConfiguration.java
index 0d26ffbc..f3e05cdc 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/util/TrainingConfiguration.java
@@ -15,24 +15,15 @@
* limitations under the License.
*/
-package opennlp.tools.ml;
-
-import java.util.Map;
+package opennlp.tools.util;
import opennlp.tools.ml.model.AbstractModel;
-import opennlp.tools.ml.model.Event;
-import opennlp.tools.ml.model.SequenceStream;
-import opennlp.tools.util.TrainingParameters;
-
-public class MockSequenceTrainer implements EventModelSequenceTrainer<Event> {
+import opennlp.tools.monitoring.StopCriteria;
+import opennlp.tools.monitoring.TrainingProgressMonitor;
- @Override
- public AbstractModel train(SequenceStream<Event> events) {
- return null;
- }
-
- @Override
- public void init(TrainingParameters trainParams, Map<String, String>
reportMap) {
- }
-
-}
+/**
+ * Configuration used for {@link AbstractModel} training.
+ * @param progMon {@link TrainingProgressMonitor} used to monitor the training
progress.
+ * @param stopCriteria {@link StopCriteria} used to abort training when the
criteria is met.
+ */
+public record TrainingConfiguration(TrainingProgressMonitor progMon,
StopCriteria stopCriteria) {}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java
index 7a7b6383..56903c1b 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/MockEventTrainer.java
@@ -23,6 +23,7 @@ import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
public class MockEventTrainer implements EventTrainer {
@@ -39,4 +40,10 @@ public class MockEventTrainer implements EventTrainer {
@Override
public void init(TrainingParameters trainingParams, Map<String, String>
reportMap) {
}
+
+ @Override
+ public void init(TrainingParameters trainParams, Map<String, String>
reportMap,
+ TrainingConfiguration config) {
+ }
+
}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
index 0d26ffbc..26e65b50 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/MockSequenceTrainer.java
@@ -22,6 +22,7 @@ import java.util.Map;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.SequenceStream;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
public class MockSequenceTrainer implements EventModelSequenceTrainer<Event> {
@@ -34,5 +35,10 @@ public class MockSequenceTrainer implements
EventModelSequenceTrainer<Event> {
@Override
public void init(TrainingParameters trainParams, Map<String, String>
reportMap) {
}
-
+
+ @Override
+ public void init(TrainingParameters trainParams, Map<String, String>
reportMap,
+ TrainingConfiguration config) {
+ }
+
}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
index a8f1224a..72388b4e 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
@@ -24,8 +24,14 @@ import org.junit.jupiter.api.Test;
import opennlp.tools.ml.TrainerFactory.TrainerType;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
+import opennlp.tools.monitoring.DefaultTrainingProgressMonitor;
+import opennlp.tools.monitoring.LogLikelihoodThresholdBreached;
+import opennlp.tools.util.TrainingConfiguration;
import opennlp.tools.util.TrainingParameters;
+import static org.junit.jupiter.api.Assertions.assertAll;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
public class TrainerFactoryTest {
private TrainingParameters mlParams;
@@ -78,4 +84,18 @@ public class TrainerFactoryTest {
Assertions.assertNotEquals(TrainerType.EVENT_MODEL_SEQUENCE_TRAINER,
trainerType);
}
+ @Test
+ void testGetEventTrainerConfiguration() {
+ mlParams.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
+
+ TrainingConfiguration config = new TrainingConfiguration(new
DefaultTrainingProgressMonitor(),
+ new LogLikelihoodThresholdBreached(mlParams));
+
+ AbstractTrainer trainer = (AbstractTrainer)
TrainerFactory.getEventTrainer(mlParams, null, config);
+
+ assertAll(() -> assertTrue(trainer.getTrainingConfiguration().progMon()
instanceof
+ DefaultTrainingProgressMonitor),
+ () -> assertTrue(trainer.getTrainingConfiguration().stopCriteria()
instanceof
+ LogLikelihoodThresholdBreached));
+ }
}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitorTest.java
b/opennlp-tools/src/test/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitorTest.java
new file mode 100644
index 00000000..59c68a6a
--- /dev/null
+++
b/opennlp-tools/src/test/java/opennlp/tools/monitoring/DefaultTrainingProgressMonitorTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+import java.util.List;
+import java.util.Map;
+
+import ch.qos.logback.classic.Level;
+import ch.qos.logback.classic.Logger;
+import ch.qos.logback.classic.spi.ILoggingEvent;
+import ch.qos.logback.core.read.ListAppender;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.slf4j.LoggerFactory;
+
+import opennlp.tools.util.TrainingParameters;
+
+import static java.util.stream.Collectors.toList;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class DefaultTrainingProgressMonitorTest {
+
+ private static final String LOGGER_NAME = "opennlp";
+ private static final Logger logger = (Logger)
LoggerFactory.getLogger(LOGGER_NAME);
+ private static final Level originalLogLevel = logger != null ?
logger.getLevel() : Level.OFF;
+
+ private TrainingProgressMonitor progressMonitor;
+ private final ListAppender<ILoggingEvent> appender = new ListAppender<>();
+
+
+ @BeforeAll
+ static void beforeAll() {
+ logger.setLevel(Level.INFO);
+ }
+
+ @BeforeEach
+ public void setup() {
+ progressMonitor = new DefaultTrainingProgressMonitor();
+ appender.list.clear();
+ logger.addAppender(appender);
+ appender.start();
+ }
+
+ @Test
+ void testFinishedIteration() {
+ progressMonitor.finishedIteration(1, 19830, 20801,
TrainingMeasure.ACCURACY, 0.953319551944618);
+ progressMonitor.finishedIteration(2, 19852, 20801,
TrainingMeasure.ACCURACY, 0.9543771934041633);
+ progressMonitor.display(true);
+
+ //Assert that two logging events are captured for two iterations.
+ List<String> actual =
appender.list.stream().map(ILoggingEvent::getMessage).
+ collect(toList());
+ List<String> expected = List.of("1: (19830/20801) Training Accuracy :
0.953319551944618",
+ "2: (19852/20801) Training Accuracy : 0.9543771934041633");
+ assertArrayEquals(expected.toArray(), actual.toArray());
+
+ }
+
+ @Test
+ void testFinishedTrainingWithStopCriteria() {
+ StopCriteria stopCriteria = new IterDeltaAccuracyUnderTolerance(new
TrainingParameters(Map.of("Tolerance",
+ .00002)));
+ progressMonitor.finishedTraining(150, stopCriteria);
+ progressMonitor.display(true);
+
+ //Assert that the logs captured the training completion message with
StopCriteria satisfied.
+ List<String> actual =
appender.list.stream().map(ILoggingEvent::getMessage).
+ collect(toList());
+ List<String> expected = List.of("Stopping: change in training set accuracy
less than {2.0E-5}");
+ assertArrayEquals(expected.toArray(), actual.toArray());
+ }
+
+ @Test
+ void testFinishedTrainingWithoutStopCriteria() {
+ progressMonitor.finishedTraining(150, null);
+ progressMonitor.display(true);
+
+ //Assert that the logs captured the training completion message when all
iterations are exhausted.
+ List<String> actual =
appender.list.stream().map(ILoggingEvent::getMessage).
+ collect(toList());
+ List<String> expected = List.of("Training Finished after completing 150
Iterations successfully.");
+ assertArrayEquals(expected.toArray(), actual.toArray());
+ }
+
+ @Test
+ void displayAndClear() {
+ progressMonitor.finishedTraining(150, null);
+ progressMonitor.display(true);
+
+ //Assert that the previous invocation of display has cleared the recorded
training progress.
+ appender.list.clear();
+ progressMonitor.display(true);
+ assertEquals(0, appender.list.size());
+ }
+
+ @Test
+ void displayAndKeep() {
+ progressMonitor.finishedTraining(150, null);
+ progressMonitor.display(false);
+
+ //Assert that the previous invocation of display has not cleared the
recorded training progress.
+ progressMonitor.display(false);
+ assertEquals(2, appender.list.size());
+ }
+
+ @AfterAll
+ static void afterAll() {
+ logger.setLevel(originalLogLevel);
+ }
+}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderToleranceTest.java
b/opennlp-tools/src/test/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderToleranceTest.java
new file mode 100644
index 00000000..4ca7c2eb
--- /dev/null
+++
b/opennlp-tools/src/test/java/opennlp/tools/monitoring/IterDeltaAccuracyUnderToleranceTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+import java.util.Map;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+
+import opennlp.tools.util.TrainingParameters;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class IterDeltaAccuracyUnderToleranceTest {
+
+ private StopCriteria stopCriteria;
+
+ @BeforeEach
+ public void setup() {
+ stopCriteria = new IterDeltaAccuracyUnderTolerance(new
TrainingParameters(Map.of("Tolerance",
+ .00002)));
+ }
+
+ @ParameterizedTest
+ @CsvSource( {"0.01,false", "-0.01,false", "0.00001,true", "-0.00001,true"})
+ void testCriteria(double val, String expectedVal) {
+ assertEquals(Boolean.parseBoolean(expectedVal), stopCriteria.test(val));
+ }
+
+ @Test
+ void testMessageIfSatisfied() {
+ assertEquals("Stopping: change in training set accuracy less than
{2.0E-5}",
+ stopCriteria.getMessageIfSatisfied());
+ }
+}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreachedTest.java
b/opennlp-tools/src/test/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreachedTest.java
new file mode 100644
index 00000000..7786847b
--- /dev/null
+++
b/opennlp-tools/src/test/java/opennlp/tools/monitoring/LogLikelihoodThresholdBreachedTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.monitoring;
+
+import java.util.Map;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+
+import opennlp.tools.util.TrainingParameters;
+
+import static
opennlp.tools.ml.maxent.GISTrainer.LOG_LIKELIHOOD_THRESHOLD_PARAM;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+
+class LogLikelihoodThresholdBreachedTest {
+
+ private StopCriteria stopCriteria;
+
+ @BeforeEach
+ public void setup() {
+ stopCriteria = new LogLikelihoodThresholdBreached(
+ new TrainingParameters(Map.of(LOG_LIKELIHOOD_THRESHOLD_PARAM,5.)));
+ }
+
+ @ParameterizedTest
+ @CsvSource({"0.01,true", "-0.01,true", "6.0,false", "-6.0,true"})
+ void testCriteria(double val, String expectedVal) {
+ assertEquals(Boolean.parseBoolean(expectedVal), stopCriteria.test(val));
+ }
+
+ @Test
+ void testMessageIfSatisfied() {
+ assertEquals("Stopping: Difference between log likelihood of current" +
+ " and previous iteration is less than threshold 5.0 .",
+ stopCriteria.getMessageIfSatisfied());
+ }
+
+}