OPENNLP-1044: Add validate(), which checks the validity of the parameters as part of the framework's processing. This closes apache/opennlp#192
Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/ca9a1d94 Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/ca9a1d94 Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/ca9a1d94 Branch: refs/heads/LangDetect Commit: ca9a1d943d4cde23fe36d0c557ddb4110bad0260 Parents: 5f96aa3 Author: koji <k...@apache.org> Authored: Mon May 8 11:00:18 2017 +0900 Committer: koji <k...@apache.org> Committed: Mon May 8 11:00:18 2017 +0900 ---------------------------------------------------------------------- .../ml/AbstractEventModelSequenceTrainer.java | 5 +-- .../opennlp/tools/ml/AbstractEventTrainer.java | 17 +++++----- .../tools/ml/AbstractSequenceTrainer.java | 5 +-- .../java/opennlp/tools/ml/AbstractTrainer.java | 26 ++++++++++++--- .../tools/ml/maxent/quasinewton/QNTrainer.java | 34 +++++++++++++------- .../tools/ml/naivebayes/NaiveBayesTrainer.java | 4 --- .../tools/ml/perceptron/PerceptronTrainer.java | 20 +++++++++--- .../SimplePerceptronSequenceTrainer.java | 26 +++++++++++---- 8 files changed, 88 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java index fdcb4b6..362a0d6 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java @@ -32,10 +32,7 @@ public abstract class AbstractEventModelSequenceTrainer extends AbstractTrainer throws IOException; public final MaxentModel train(SequenceStream events) throws IOException { - - 
if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); MaxentModel model = doTrain(events); addToReport(AbstractTrainer.TRAINER_TYPE_PARAM, http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java index 330307a..dc75ffe 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java @@ -42,7 +42,13 @@ public abstract class AbstractEventTrainer extends AbstractTrainer implements Ev public AbstractEventTrainer(TrainingParameters parameters) { super(parameters); } - + + @Override + public void validate() { + super.validate(); + } + + @Deprecated @Override public boolean isValid() { return super.isValid(); @@ -66,9 +72,7 @@ public abstract class AbstractEventTrainer extends AbstractTrainer implements Ev public abstract MaxentModel doTrain(DataIndexer indexer) throws IOException; public final MaxentModel train(DataIndexer indexer) throws IOException { - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); if (indexer.getOutcomeLabels().length <= 1) { throw new InsufficientTrainingDataException("Training data must contain more than one outcome"); @@ -80,10 +84,7 @@ public abstract class AbstractEventTrainer extends AbstractTrainer implements Ev } public final MaxentModel train(ObjectStream<Event> events) throws IOException { - - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); HashSumEventStream hses = new HashSumEventStream(events); DataIndexer indexer = getDataIndexer(hses); 
http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java index 2d48624..19ecc4b 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java @@ -32,10 +32,7 @@ public abstract class AbstractSequenceTrainer extends AbstractTrainer implements throws IOException; public final SequenceClassificationModel<String> train(SequenceStream events) throws IOException { - - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); SequenceClassificationModel<String> model = doTrain(events); addToReport(AbstractTrainer.TRAINER_TYPE_PARAM, SequenceTrainer.SEQUENCE_VALUE); http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java index 070b96c..32c5df6 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java @@ -74,20 +74,36 @@ public abstract class AbstractTrainer { return trainingParameters.getIntParameter(ITERATIONS_PARAM, ITERATIONS_DEFAULT); } - public boolean isValid() { - + /** + * Check parameters. If subclass overrides this, it should call super.validate(); + * + * @throws java.lang.IllegalArgumentException + */ + public void validate() { // TODO: Need to validate all parameters correctly ... error prone?! 
- // should validate if algorithm is set? What about the Parser? try { trainingParameters.getIntParameter(CUTOFF_PARAM, CUTOFF_DEFAULT); trainingParameters.getIntParameter(ITERATIONS_PARAM, ITERATIONS_DEFAULT); } catch (NumberFormatException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * @deprecated Use {@link #validate()} instead. + * @return + */ + @Deprecated + public boolean isValid() { + try { + validate(); + return true; + } + catch (IllegalArgumentException e) { return false; } - - return true; } /** http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java index 7a1a74f..daa90a4 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java @@ -115,42 +115,52 @@ public class QNTrainer extends AbstractEventTrainer { init(new TrainingParameters(trainParams),reportMap); } - public boolean isValid() { - - if (!super.isValid()) { - return false; - } + @Override + public void validate() { + super.validate(); String algorithmName = getAlgorithm(); if (algorithmName != null && !(MAXENT_QN_VALUE.equals(algorithmName))) { - return false; + throw new IllegalArgumentException("algorithmName must be MAXENT_QN"); } // Number of Hessian updates to remember if (m < 0) { - return false; + throw new IllegalArgumentException( + "Number of Hessian updates to remember must be >= 0"); } // Maximum number of function evaluations if (maxFctEval < 0) { - return false; + throw new IllegalArgumentException( + "Maximum number of function evaluations must be >= 0"); } // Number of threads must be >= 1 if (threads < 1) { - return false; 
+ throw new IllegalArgumentException("Number of threads must be >= 1"); } // Regularization costs must be >= 0 if (l1Cost < 0) { - return false; + throw new IllegalArgumentException("Regularization costs must be >= 0"); } if (l2Cost < 0) { - return false; + throw new IllegalArgumentException("Regularization costs must be >= 0"); } + } - return true; + @Deprecated + @Override + public boolean isValid() { + try { + validate(); + return true; + } + catch (IllegalArgumentException e) { + return false; + } } public boolean isSortAndMerge() { http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java index 629c222..69ef44e 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java @@ -102,10 +102,6 @@ public class NaiveBayesTrainer extends AbstractEventTrainer { } public AbstractModel doTrain(DataIndexer indexer) throws IOException { - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } - return this.trainModel(indexer); } http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java index 129c576..b73eaca 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java @@ 
-84,7 +84,21 @@ public class PerceptronTrainer extends AbstractEventTrainer { public PerceptronTrainer(TrainingParameters parameters) { super(parameters); } - + + @Override + public void validate() { + super.validate(); + + String algorithmName = getAlgorithm(); + if (algorithmName != null) { + if (!PERCEPTRON_VALUE.equals(algorithmName)) { + throw new IllegalArgumentException("algorithmName must be PERCEPTRON"); + } + } + } + + @Deprecated + @Override public boolean isValid() { if (!super.isValid()) { return false; @@ -104,10 +118,6 @@ public class PerceptronTrainer extends AbstractEventTrainer { } public AbstractModel doTrain(DataIndexer indexer) throws IOException { - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } - int iterations = getIterations(); int cutoff = getCutoff(); http://git-wip-us.apache.org/repos/asf/opennlp/blob/ca9a1d94/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java index 5fc4bbe..a9ac516 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java @@ -83,16 +83,28 @@ public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceT public SimplePerceptronSequenceTrainer() { } - public boolean isValid() { - - if (!super.isValid()) { - return false; - } + @Override + public void validate() { + super.validate(); String algorithmName = getAlgorithm(); + if (algorithmName != null) { + if (!PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName)) { + throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE"); + } + } + } - return 
!(algorithmName != null - && !(PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))); + @Deprecated + @Override + public boolean isValid() { + try { + validate(); + return true; + } + catch (IllegalArgumentException e) { + return false; + } } public AbstractModel doTrain(SequenceStream events) throws IOException {