Repository: opennlp
Updated Branches:
  refs/heads/master 89f9e55b7 -> 6ffdfbb8c


OPENNLP-946: GISTrainer should extend AbstractEventTrainer

This closes #104


Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo
Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/6ffdfbb8
Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/6ffdfbb8
Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/6ffdfbb8

Branch: refs/heads/master
Commit: 6ffdfbb8c97d7f70e5cb57be566f1c90e35e03ba
Parents: 89f9e55
Author: smarthi <[email protected]>
Authored: Mon Jan 30 13:57:55 2017 -0500
Committer: Jörn Kottmann <[email protected]>
Committed: Tue Jan 31 12:36:50 2017 +0100

----------------------------------------------------------------------
 .../java/opennlp/tools/ml/AbstractTrainer.java  |  6 +-
 .../java/opennlp/tools/ml/TrainerFactory.java   | 17 +++--
 .../main/java/opennlp/tools/ml/maxent/GIS.java  |  2 +
 .../opennlp/tools/ml/maxent/GISTrainer.java     | 72 ++++++++++++++++++--
 .../opennlp/tools/util/model/ModelUtil.java     |  4 +-
 .../opennlp/tools/ml/TrainerFactoryTest.java    |  9 +--
 .../tools/ml/maxent/GISIndexingTest.java        |  2 +-
 .../tools/ml/maxent/MaxentPrepAttachTest.java   |  4 +-
 .../tools/ml/maxent/RealValueModelTest.java     |  5 +-
 9 files changed, 93 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
index 6de81ef..070b96c 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java
@@ -20,7 +20,7 @@ package opennlp.tools.ml;
 import java.util.HashMap;
 import java.util.Map;
 
-import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.ml.maxent.GISTrainer;
 import opennlp.tools.util.TrainingParameters;
 
 public abstract class AbstractTrainer {
@@ -63,7 +63,7 @@ public abstract class AbstractTrainer {
   }
 
   public String getAlgorithm() {
-    return trainingParameters.getStringParameter(ALGORITHM_PARAM, 
GIS.MAXENT_VALUE);
+    return trainingParameters.getStringParameter(ALGORITHM_PARAM, 
GISTrainer.MAXENT_VALUE);
   }
 
   public int getCutoff() {
@@ -123,7 +123,7 @@ public abstract class AbstractTrainer {
   /**
    * Use the PluggableParameters directly...
    * @param key
-   * @param value
+   * @param defaultValue
    */
   @Deprecated
   protected boolean getBooleanParam(String key, boolean defaultValue) {

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
index f825491..7897cf2 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
@@ -18,12 +18,11 @@
 package opennlp.tools.ml;
 
 import java.lang.reflect.Constructor;
-
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
-import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.ml.maxent.GISTrainer;
 import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
 import opennlp.tools.ml.naivebayes.NaiveBayesTrainer;
 import opennlp.tools.ml.perceptron.PerceptronTrainer;
@@ -44,7 +43,7 @@ public class TrainerFactory {
 
   static {
     Map<String, Class> _trainers = new HashMap<>();
-    _trainers.put(GIS.MAXENT_VALUE, GIS.class);
+    _trainers.put(GISTrainer.MAXENT_VALUE, GISTrainer.class);
     _trainers.put(QNTrainer.MAXENT_QN_VALUE, QNTrainer.class);
     _trainers.put(PerceptronTrainer.PERCEPTRON_VALUE, PerceptronTrainer.class);
     _trainers.put(SimplePerceptronSequenceTrainer.PERCEPTRON_SEQUENCE_VALUE,
@@ -57,7 +56,7 @@ public class TrainerFactory {
   /**
    * Determines the trainer type based on the ALGORITHM_PARAM value.
    *
-   * @param trainParams
+   * @param trainParams - Map of training parameters
    * @return the trainer type or null if type couldn't be determined.
    */
   public static TrainerType getTrainerType(Map<String, String> trainParams) {
@@ -161,7 +160,7 @@ public class TrainerFactory {
     String trainerType = trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
     if (trainerType == null) {
       // default to MAXENT
-      AbstractEventTrainer trainer = new GIS();
+      AbstractEventTrainer trainer = new GISTrainer();
       trainer.init(trainParams, reportMap);
       return trainer;
     }
@@ -193,10 +192,14 @@ public class TrainerFactory {
 
     try {
       String cutoffString = trainParams.get(AbstractTrainer.CUTOFF_PARAM);
-      if (cutoffString != null) Integer.parseInt(cutoffString);
+      if (cutoffString != null) {
+        Integer.parseInt(cutoffString);
+      }
 
       String iterationsString = 
trainParams.get(AbstractTrainer.ITERATIONS_PARAM);
-      if (iterationsString != null) Integer.parseInt(iterationsString);
+      if (iterationsString != null) {
+        Integer.parseInt(iterationsString);
+      }
     }
     catch (NumberFormatException e) {
       return false;

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java
index cdde6eb..97c214d 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java
@@ -31,7 +31,9 @@ import opennlp.tools.util.TrainingParameters;
 /**
  * A Factory class which uses instances of GISTrainer to create and train
  * GISModels.
+ * @deprecated use {@link GISTrainer}
  */
+@Deprecated
 public class GIS extends AbstractEventTrainer {
 
   public static final String MAXENT_VALUE = "MAXENT";

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
index 34b640b..7e220c3 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java
@@ -27,9 +27,12 @@ import java.util.concurrent.ExecutorCompletionService;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.DataIndexer;
 import opennlp.tools.ml.model.EvalParameters;
 import opennlp.tools.ml.model.Event;
+import opennlp.tools.ml.model.MaxentModel;
 import opennlp.tools.ml.model.MutableContext;
 import opennlp.tools.ml.model.OnePassDataIndexer;
 import opennlp.tools.ml.model.Prior;
@@ -55,7 +58,7 @@ import opennlp.tools.util.TrainingParameters;
  * relative entropy between the distribution specified by the empirical 
constraints of the training
  * data and the specified prior.  By default, the uniform distribution is used 
as the prior.
  */
-public class GISTrainer {
+public class GISTrainer extends AbstractEventTrainer {
 
   private static final double LLThreshold = 0.0001;
   private final boolean printMessages;
@@ -134,14 +137,46 @@ public class GISTrainer {
    */
   private EvalParameters evalParams;
 
+  public static final String MAXENT_VALUE = "MAXENT";
+
+  /**
+   * If we are using smoothing, this is used as the "number" of times we want
+   * the trainer to imagine that it saw a feature that it actually didn't see.
+   * Defaulted to 0.1.
+   */
+  private static final double SMOOTHING_OBSERVATION = 0.1;
+
+  private static final String SMOOTHING_PARAM = "smoothing";
+  private static final boolean SMOOTHING_DEFAULT = false;
+
   /**
    * Creates a new <code>GISTrainer</code> instance which does not print
    * progress messages about training to STDOUT.
    */
-  GISTrainer() {
+  public GISTrainer() {
     printMessages = false;
   }
 
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+    int iterations = getIterations();
+
+    AbstractModel model;
+
+    boolean smoothing = 
trainingParameters.getBooleanParameter(SMOOTHING_PARAM, SMOOTHING_DEFAULT);
+    int threads = 
trainingParameters.getIntParameter(TrainingParameters.THREADS_PARAM, 1);
+
+    this.setSmoothing(smoothing);
+    model = trainModel(iterations, indexer, threads);
+
+    return model;
+  }
+
   /**
    * Creates a new <code>GISTrainer</code> instance.
    *
@@ -186,6 +221,20 @@ public class GISTrainer {
   }
 
   /**
+   * Train a model using the GIS algorithm, assuming 100 iterations and no
+   * cutoff.
+   *
+   * @param eventStream
+   *          The EventStream holding the data on which this model will be
+   *          trained.
+   * @return The newly trained model, which can be used immediately or saved to
+   *         disk using an opennlp.tools.ml.maxent.io.GISModelWriter object.
+   */
+  public GISModel trainModel(ObjectStream<Event> eventStream) throws 
IOException {
+    return trainModel(eventStream, 100, 0);
+  }
+
+  /**
    * Trains a GIS model on the event in the specified event stream, using the 
specified number
    * of iterations and the specified count cutoff.
    *
@@ -198,8 +247,8 @@ public class GISTrainer {
                              int cutoff) throws IOException {
     DataIndexer indexer = new OnePassDataIndexer();
     TrainingParameters indexingParameters = new TrainingParameters();
-    indexingParameters.put(GIS.CUTOFF_PARAM, Integer.toString(cutoff));
-    indexingParameters.put(GIS.ITERATIONS_PARAM, Integer.toString(iterations));
+    indexingParameters.put(GISTrainer.CUTOFF_PARAM, Integer.toString(cutoff));
+    indexingParameters.put(GISTrainer.ITERATIONS_PARAM, 
Integer.toString(iterations));
     Map<String, String> reportMap = new HashMap<>();
     indexer.init(indexingParameters, reportMap);
     indexer.index(eventStream);
@@ -223,6 +272,19 @@ public class GISTrainer {
    *
    * @param iterations The number of GIS iterations to perform.
    * @param di         The data indexer used to compress events in memory.
+   * @param threads
+   * @return The newly trained model, which can be used immediately or saved
+   * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object.
+   */
+  public GISModel trainModel(int iterations, DataIndexer di, int threads) {
+    return trainModel(iterations, di, new UniformPrior(), threads);
+  }
+
+  /**
+   * Train a model using the GIS algorithm.
+   *
+   * @param iterations The number of GIS iterations to perform.
+   * @param di         The data indexer used to compress events in memory.
    * @param modelPrior The prior distribution used to train this model.
    * @return The newly trained model, which can be used immediately or saved
    * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object.
@@ -529,7 +591,7 @@ public class GISTrainer {
     return loglikelihood;
   }
 
-  private void display(String s) {
+  protected void display(String s) {
     if (printMessages) {
       System.out.print(s);
     }

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java b/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java
index ab5687f..85f6e12 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/util/model/ModelUtil.java
@@ -28,7 +28,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
-import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.ml.maxent.GISTrainer;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.GenericModelWriter;
 import opennlp.tools.ml.model.MaxentModel;
@@ -141,7 +141,7 @@ public final class ModelUtil {
    */
   public static TrainingParameters createDefaultTrainingParameters() {
     TrainingParameters mlParams = new TrainingParameters();
-    mlParams.put(TrainingParameters.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
     mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(100));
     mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5));
 

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
index 01482f3..092742c 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/TrainerFactoryTest.java
@@ -22,7 +22,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import opennlp.tools.ml.TrainerFactory.TrainerType;
-import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.ml.maxent.GISTrainer;
 import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
 import opennlp.tools.util.TrainingParameters;
 
@@ -33,7 +33,7 @@ public class TrainerFactoryTest {
   @Before
   public void setup() {
     mlParams = new TrainingParameters();
-    mlParams.put(TrainingParameters.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+    mlParams.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
     mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(10));
     mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5));
   }
@@ -73,11 +73,8 @@ public class TrainerFactoryTest {
 
   @Test
   public void testIsSequenceTrainerFalse() {
-    mlParams.put(AbstractTrainer.ALGORITHM_PARAM,
-        GIS.MAXENT_VALUE);
-
+    mlParams.put(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
     TrainerType trainerType = 
TrainerFactory.getTrainerType(mlParams.getSettings());
-
     
Assert.assertFalse(TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType));
   }
 

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java
index 771d3a5..5a98f73 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISIndexingTest.java
@@ -120,7 +120,7 @@ public class GISIndexingTest {
     // guarantee that you have a GIS trainer...
     EventTrainer trainer =
         TrainerFactory.getEventTrainer(parameters.getSettings(), new 
HashMap<>());
-    Assert.assertEquals("opennlp.tools.ml.maxent.GIS", 
trainer.getClass().getName());
+    Assert.assertEquals("opennlp.tools.ml.maxent.GISTrainer", 
trainer.getClass().getName());
     AbstractEventTrainer aeTrainer = (AbstractEventTrainer)trainer;
     // guarantee that you have a OnePassDataIndexer ...
     DataIndexer di = aeTrainer.getDataIndexer(eventStream);

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java
index b531b33..74b13de 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java
@@ -76,7 +76,7 @@ public class MaxentPrepAttachTest {
   public void testMaxentOnPrepAttachDataWithParams() throws IOException {
 
     Map<String, String> trainParams = new HashMap<>();
-    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
     trainParams.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
         AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE);
     trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1));
@@ -91,7 +91,7 @@ public class MaxentPrepAttachTest {
   public void testMaxentOnPrepAttachDataWithParamsDefault() throws IOException 
{
 
     Map<String, String> trainParams = new HashMap<>();
-    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE);
+    trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
 
     EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null);
     MaxentModel model = 
trainer.train(PrepAttachDataUtil.createTrainingStream());

http://git-wip-us.apache.org/repos/asf/opennlp/blob/6ffdfbb8/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java
----------------------------------------------------------------------
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java
index 431ea96..850d9bc 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java
@@ -46,17 +46,18 @@ public class RealValueModelTest {
   @Test
   public void testRealValuedWeightsVsRepeatWeighting() throws IOException {
     GISModel realModel;
+    GISTrainer gisTrainer = new GISTrainer();
     try (RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
         
"src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt"))
 {
       testDataIndexer.index(rvfes1);
-      realModel = GIS.trainModel(100, testDataIndexer);
+      realModel = gisTrainer.trainModel(100, testDataIndexer);
     }
 
     GISModel repeatModel;
     try (FileEventStream rvfes2 = new FileEventStream(
         
"src/test/resources/data/opennlp/maxent/repeat-weighting-training-data.txt")) {
       testDataIndexer.index(rvfes2);
-      repeatModel = GIS.trainModel(100,testDataIndexer);
+      repeatModel = gisTrainer.trainModel(100,testDataIndexer);
     }
 
     String[] features2Classify = new String[] {"feature2","feature5"};

Reply via email to