Repository: opennlp Updated Branches: refs/heads/master 89f9e55b7 -> 6ffdfbb8c
OPENNLP-946: GISTrainer should extend AbstractEventTrainer This closes #104 Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/6ffdfbb8 Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/6ffdfbb8 Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/6ffdfbb8 Branch: refs/heads/master Commit: 6ffdfbb8c97d7f70e5cb57be566f1c90e35e03ba Parents: 89f9e55 Author: smarthi <[email protected]> Authored: Mon Jan 30 13:57:55 2017 -0500 Committer: Jörn Kottmann <[email protected]> Committed: Tue Jan 31 12:36:50 2017 +0100 ---------------------------------------------------------------------- .../java/opennlp/tools/ml/AbstractTrainer.java | 6 +- .../java/opennlp/tools/ml/TrainerFactory.java | 17 +++-- .../main/java/opennlp/tools/ml/maxent/GIS.java | 2 + .../opennlp/tools/ml/maxent/GISTrainer.java | 72 ++++++++++++++++++-- .../opennlp/tools/util/model/ModelUtil.java | 4 +- .../opennlp/tools/ml/TrainerFactoryTest.java | 9 +-- .../tools/ml/maxent/GISIndexingTest.java | 2 +- .../tools/ml/maxent/MaxentPrepAttachTest.java | 4 +- .../tools/ml/maxent/RealValueModelTest.java | 5 +- 9 files changed, 93 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java index 6de81ef..070b96c 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java @@ -20,7 +20,7 @@ package opennlp.tools.ml; import java.util.HashMap; import java.util.Map; -import opennlp.tools.ml.maxent.GIS; +import opennlp.tools.ml.maxent.GISTrainer; import 
opennlp.tools.util.TrainingParameters; public abstract class AbstractTrainer { @@ -63,7 +63,7 @@ public abstract class AbstractTrainer { } public String getAlgorithm() { - return trainingParameters.getStringParameter(ALGORITHM_PARAM, GIS.MAXENT_VALUE); + return trainingParameters.getStringParameter(ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); } public int getCutoff() { @@ -123,7 +123,7 @@ public abstract class AbstractTrainer { /** * Use the PluggableParameters directly... * @param key - * @param value + * @param defaultValue */ @Deprecated protected boolean getBooleanParam(String key, boolean defaultValue) { http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java index f825491..7897cf2 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java @@ -18,12 +18,11 @@ package opennlp.tools.ml; import java.lang.reflect.Constructor; - import java.util.Collections; import java.util.HashMap; import java.util.Map; -import opennlp.tools.ml.maxent.GIS; +import opennlp.tools.ml.maxent.GISTrainer; import opennlp.tools.ml.maxent.quasinewton.QNTrainer; import opennlp.tools.ml.naivebayes.NaiveBayesTrainer; import opennlp.tools.ml.perceptron.PerceptronTrainer; @@ -44,7 +43,7 @@ public class TrainerFactory { static { Map<String, Class> _trainers = new HashMap<>(); - _trainers.put(GIS.MAXENT_VALUE, GIS.class); + _trainers.put(GISTrainer.MAXENT_VALUE, GISTrainer.class); _trainers.put(QNTrainer.MAXENT_QN_VALUE, QNTrainer.class); _trainers.put(PerceptronTrainer.PERCEPTRON_VALUE, PerceptronTrainer.class); _trainers.put(SimplePerceptronSequenceTrainer.PERCEPTRON_SEQUENCE_VALUE, @@ -57,7 +56,7 @@ public class 
TrainerFactory { /** * Determines the trainer type based on the ALGORITHM_PARAM value. * - * @param trainParams + * @param trainParams - Map of training parameters * @return the trainer type or null if type couldn't be determined. */ public static TrainerType getTrainerType(Map<String, String> trainParams) { @@ -161,7 +160,7 @@ public class TrainerFactory { String trainerType = trainParams.get(AbstractTrainer.ALGORITHM_PARAM); if (trainerType == null) { // default to MAXENT - AbstractEventTrainer trainer = new GIS(); + AbstractEventTrainer trainer = new GISTrainer(); trainer.init(trainParams, reportMap); return trainer; } @@ -193,10 +192,14 @@ public class TrainerFactory { try { String cutoffString = trainParams.get(AbstractTrainer.CUTOFF_PARAM); - if (cutoffString != null) Integer.parseInt(cutoffString); + if (cutoffString != null) { + Integer.parseInt(cutoffString); + } String iterationsString = trainParams.get(AbstractTrainer.ITERATIONS_PARAM); - if (iterationsString != null) Integer.parseInt(iterationsString); + if (iterationsString != null) { + Integer.parseInt(iterationsString); + } } catch (NumberFormatException e) { return false; http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java index cdde6eb..97c214d 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java @@ -31,7 +31,9 @@ import opennlp.tools.util.TrainingParameters; /** * A Factory class which uses instances of GISTrainer to create and train * GISModels. 
+ * @deprecated use {@link GISTrainer} */ +@Deprecated public class GIS extends AbstractEventTrainer { public static final String MAXENT_VALUE = "MAXENT"; http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java index 34b640b..7e220c3 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java @@ -27,9 +27,12 @@ import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import opennlp.tools.ml.AbstractEventTrainer; +import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.DataIndexer; import opennlp.tools.ml.model.EvalParameters; import opennlp.tools.ml.model.Event; +import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.ml.model.MutableContext; import opennlp.tools.ml.model.OnePassDataIndexer; import opennlp.tools.ml.model.Prior; @@ -55,7 +58,7 @@ import opennlp.tools.util.TrainingParameters; * relative entropy between the distribution specified by the empirical constraints of the training * data and the specified prior. By default, the uniform distribution is used as the prior. */ -public class GISTrainer { +public class GISTrainer extends AbstractEventTrainer { private static final double LLThreshold = 0.0001; private final boolean printMessages; @@ -134,14 +137,46 @@ public class GISTrainer { */ private EvalParameters evalParams; + public static final String MAXENT_VALUE = "MAXENT"; + + /** + * If we are using smoothing, this is used as the "number" of times we want + * the trainer to imagine that it saw a feature that it actually didn't see. + * Defaulted to 0.1. 
+ */ + private static final double SMOOTHING_OBSERVATION = 0.1; + + private static final String SMOOTHING_PARAM = "smoothing"; + private static final boolean SMOOTHING_DEFAULT = false; + /** * Creates a new <code>GISTrainer</code> instance which does not print * progress messages about training to STDOUT. */ - GISTrainer() { + public GISTrainer() { printMessages = false; } + @Override + public boolean isSortAndMerge() { + return true; + } + + @Override + public MaxentModel doTrain(DataIndexer indexer) throws IOException { + int iterations = getIterations(); + + AbstractModel model; + + boolean smoothing = trainingParameters.getBooleanParameter(SMOOTHING_PARAM, SMOOTHING_DEFAULT); + int threads = trainingParameters.getIntParameter(TrainingParameters.THREADS_PARAM, 1); + + this.setSmoothing(smoothing); + model = trainModel(iterations, indexer, threads); + + return model; + } + /** * Creates a new <code>GISTrainer</code> instance. * @@ -186,6 +221,20 @@ public class GISTrainer { } /** + * Train a model using the GIS algorithm, assuming 100 iterations and no + * cutoff. + * + * @param eventStream + * The EventStream holding the data on which this model will be + * trained. + * @return The newly trained model, which can be used immediately or saved to + * disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. + */ + public GISModel trainModel(ObjectStream<Event> eventStream) throws IOException { + return trainModel(eventStream, 100, 0); + } + + /** * Trains a GIS model on the event in the specified event stream, using the specified number * of iterations and the specified count cutoff. 
* @@ -198,8 +247,8 @@ public class GISTrainer { int cutoff) throws IOException { DataIndexer indexer = new OnePassDataIndexer(); TrainingParameters indexingParameters = new TrainingParameters(); - indexingParameters.put(GIS.CUTOFF_PARAM, Integer.toString(cutoff)); - indexingParameters.put(GIS.ITERATIONS_PARAM, Integer.toString(iterations)); + indexingParameters.put(GISTrainer.CUTOFF_PARAM, Integer.toString(cutoff)); + indexingParameters.put(GISTrainer.ITERATIONS_PARAM, Integer.toString(iterations)); Map<String, String> reportMap = new HashMap<>(); indexer.init(indexingParameters, reportMap); indexer.index(eventStream); @@ -223,6 +272,19 @@ public class GISTrainer { * * @param iterations The number of GIS iterations to perform. * @param di The data indexer used to compress events in memory. + * @param threads The number of threads to use for training. + * @return The newly trained model, which can be used immediately or saved + * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. + */ + public GISModel trainModel(int iterations, DataIndexer di, int threads) { + return trainModel(iterations, di, new UniformPrior(), threads); + } + + /** + * Train a model using the GIS algorithm. + * + * @param iterations The number of GIS iterations to perform. + * @param di The data indexer used to compress events in memory. * @param modelPrior The prior distribution used to train this model. * @return The newly trained model, which can be used immediately or saved * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object.
@@ -529,7 +591,7 @@ public class GISTrainer { return loglikelihood; } - private void display(String s) { + protected void display(String s) { if (printMessages) { System.out.print(s); } http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java b/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java index ab5687f..85f6e12 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java +++ b/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java @@ -28,7 +28,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import opennlp.tools.ml.maxent.GIS; +import opennlp.tools.ml.maxent.GISTrainer; import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.GenericModelWriter; import opennlp.tools.ml.model.MaxentModel; @@ -141,7 +141,7 @@ public final class ModelUtil { */ public static TrainingParameters createDefaultTrainingParameters() { TrainingParameters mlParams = new TrainingParameters(); - mlParams.put(TrainingParameters.ALGORITHM_PARAM, GIS.MAXENT_VALUE); + mlParams.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(100)); mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5)); http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java index 01482f3..092742c 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java +++ 
b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java @@ -22,7 +22,7 @@ import org.junit.Before; import org.junit.Test; import opennlp.tools.ml.TrainerFactory.TrainerType; -import opennlp.tools.ml.maxent.GIS; +import opennlp.tools.ml.maxent.GISTrainer; import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer; import opennlp.tools.util.TrainingParameters; @@ -33,7 +33,7 @@ public class TrainerFactoryTest { @Before public void setup() { mlParams = new TrainingParameters(); - mlParams.put(TrainingParameters.ALGORITHM_PARAM, GIS.MAXENT_VALUE); + mlParams.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(10)); mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5)); } @@ -73,11 +73,8 @@ public class TrainerFactoryTest { @Test public void testIsSequenceTrainerFalse() { - mlParams.put(AbstractTrainer.ALGORITHM_PARAM, - GIS.MAXENT_VALUE); - + mlParams.put(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); TrainerType trainerType = TrainerFactory.getTrainerType(mlParams.getSettings()); - Assert.assertFalse(TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)); } http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java index 771d3a5..5a98f73 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java @@ -120,7 +120,7 @@ public class GISIndexingTest { // guarantee that you have a GIS trainer... 
EventTrainer trainer = TrainerFactory.getEventTrainer(parameters.getSettings(), new HashMap<>()); - Assert.assertEquals("opennlp.tools.ml.maxent.GIS", trainer.getClass().getName()); + Assert.assertEquals("opennlp.tools.ml.maxent.GISTrainer", trainer.getClass().getName()); AbstractEventTrainer aeTrainer = (AbstractEventTrainer)trainer; // guarantee that you have a OnePassDataIndexer ... DataIndexer di = aeTrainer.getDataIndexer(eventStream); http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java index b531b33..74b13de 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java @@ -76,7 +76,7 @@ public class MaxentPrepAttachTest { public void testMaxentOnPrepAttachDataWithParams() throws IOException { Map<String, String> trainParams = new HashMap<>(); - trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE); + trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); trainParams.put(AbstractEventTrainer.DATA_INDEXER_PARAM, AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE); trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1)); @@ -91,7 +91,7 @@ public class MaxentPrepAttachTest { public void testMaxentOnPrepAttachDataWithParamsDefault() throws IOException { Map<String, String> trainParams = new HashMap<>(); - trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE); + trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE); EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null); MaxentModel model = 
trainer.train(PrepAttachDataUtil.createTrainingStream()); http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java index 431ea96..850d9bc 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java @@ -46,17 +46,18 @@ public class RealValueModelTest { @Test public void testRealValuedWeightsVsRepeatWeighting() throws IOException { GISModel realModel; + GISTrainer gisTrainer = new GISTrainer(); try (RealValueFileEventStream rvfes1 = new RealValueFileEventStream( "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt")) { testDataIndexer.index(rvfes1); - realModel = GIS.trainModel(100, testDataIndexer); + realModel = gisTrainer.trainModel(100, testDataIndexer); } GISModel repeatModel; try (FileEventStream rvfes2 = new FileEventStream( "src/test/resources/data/opennlp/maxent/repeat-weighting-training-data.txt")) { testDataIndexer.index(rvfes2); - repeatModel = GIS.trainModel(100,testDataIndexer); + repeatModel = gisTrainer.trainModel(100,testDataIndexer); } String[] features2Classify = new String[] {"feature2","feature5"};
