This is an automated email from the ASF dual-hosted git repository.
jzemerick pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 3224ff51 OPENNLP-1495 Reduce code duplication in opennlp.tools.ml
package (#534)
3224ff51 is described below
commit 3224ff519b74b137a6325c7aa46653ac388f93da
Author: Martin Wiesner <[email protected]>
AuthorDate: Sat May 20 20:00:44 2023 +0200
OPENNLP-1495 Reduce code duplication in opennlp.tools.ml package (#534)
- reduces code duplication in PerceptronModelWriter, GISModelWriter, and
NaiveBayesModelWriter by introducing AbstractMLModelWriter
- simplifies related, existing test classes by using ParameterizedTest now
- optimizes PrepAttachDataUtil to cache already-read ppa files, avoiding
repeated IO during execution of the whole OpenNLP tools test suite
- reduces test execution times by ~5-10 %
- adjusts the 'forbiddenapis' plugin's bundledSignature to more precisely
cover 'jdk-deprecated-${java.version}' (currently: 11); the previous default
was 1.7 (outdated)
---
...ModelWriter.java => AbstractMLModelWriter.java} | 134 +++++----------------
.../opennlp/tools/ml/maxent/io/GISModelWriter.java | 107 ++++++----------
.../tools/ml/naivebayes/NaiveBayesModelWriter.java | 117 +++---------------
.../tools/ml/perceptron/PerceptronModelWriter.java | 82 ++-----------
.../java/opennlp/tools/ml/PrepAttachDataUtil.java | 35 +++---
.../ml/naivebayes/AbstractNaiveBayesTest.java | 50 ++++++++
.../ml/naivebayes/NaiveBayesCorrectnessTest.java | 109 ++++-------------
.../naivebayes/NaiveBayesModelReadWriteTest.java | 12 +-
.../ml/naivebayes/NaiveBayesPrepAttachTest.java | 25 ++--
.../NaiveBayesSerializedCorrectnessTest.java | 103 +++++-----------
pom.xml | 2 +-
11 files changed, 235 insertions(+), 541 deletions(-)
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractMLModelWriter.java
similarity index 60%
copy from
opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
copy to opennlp-tools/src/main/java/opennlp/tools/ml/AbstractMLModelWriter.java
index eb329029..ba11a370 100644
---
a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractMLModelWriter.java
@@ -15,92 +15,60 @@
* limitations under the License.
*/
-package opennlp.tools.ml.naivebayes;
+package opennlp.tools.ml;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
-import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
import opennlp.tools.ml.model.ComparablePredicate;
import opennlp.tools.ml.model.Context;
-/**
- * The base class for {@link NaiveBayesModel} writers.
- * <p>
- * It provides the {@link #persist()} method which takes care of the structure
- * of a stored document, and requires an extending class to define precisely
- * how the data should be stored.
- *
- * @see NaiveBayesModel
- * @see AbstractModelWriter
- */
-public abstract class NaiveBayesModelWriter extends AbstractModelWriter {
+public abstract class AbstractMLModelWriter extends AbstractModelWriter {
- private static final Logger logger =
LoggerFactory.getLogger(NaiveBayesModelWriter.class);
+ private static final Logger logger =
LoggerFactory.getLogger(AbstractMLModelWriter.class);
protected Context[] PARAMS;
protected String[] OUTCOME_LABELS;
protected String[] PRED_LABELS;
- int numOutcomes;
-
- public NaiveBayesModelWriter(AbstractModel model) {
-
- Object[] data = model.getDataStructures();
- this.numOutcomes = model.getNumOutcomes();
- PARAMS = (Context[]) data[0];
-
- @SuppressWarnings("unchecked")
- Map<String, Context> pmap = (Map<String, Context>) data[1];
-
- OUTCOME_LABELS = (String[]) data[2];
- PARAMS = new Context[pmap.size()];
- PRED_LABELS = new String[pmap.size()];
-
- int i = 0;
- for (Map.Entry<String, Context> pred : pmap.entrySet()) {
- PRED_LABELS[i] = pred.getKey();
- PARAMS[i] = pred.getValue();
- i++;
- }
- }
+
+ protected int numOutcomes;
/**
- * Sorts and optimizes the model parameters.
+ * Sorts and optimizes the model parameters. Thereby, parameters with
+ * {@code 0} weight and predicates with no parameters are removed.
*
* @return A {@link ComparablePredicate[]}.
*/
- protected ComparablePredicate[] sortValues() {
-
- ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length];
-
- int numParams = 0;
- for (int pid = 0; pid < PARAMS.length; pid++) {
- int[] predkeys = PARAMS[pid].getOutcomes();
- // Arrays.sort(predkeys);
- int numActive = predkeys.length;
- double[] activeParams = PARAMS[pid].getParameters();
-
- numParams += numActive;
- /*
- * double[] activeParams = new double[numActive];
- *
- * int id = 0; for (int i=0; i < predkeys.length; i++) { int oid =
- * predkeys[i]; activeOutcomes[id] = oid; activeParams[id] =
- * PARAMS[pid].getParams(oid); id++; }
- */
- sortPreds[pid] = new ComparablePredicate(PRED_LABELS[pid],
- predkeys, activeParams);
- }
+ protected abstract ComparablePredicate[] sortValues();
- Arrays.sort(sortPreds);
- return sortPreds;
+ /**
+ * Computes outcome patterns via {@link ComparablePredicate[] predicates}.
+ *
+ * @return A {@link List} of {@link List<ComparablePredicate>} that represent
+ * the outcomes patterns.
+ */
+ protected List<List<ComparablePredicate>>
computeOutcomePatterns(ComparablePredicate[] sorted) {
+ ComparablePredicate cp = sorted[0];
+ List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>();
+ List<ComparablePredicate> newGroup = new ArrayList<>();
+ for (ComparablePredicate predicate : sorted) {
+ if (cp.compareTo(predicate) == 0) {
+ newGroup.add(predicate);
+ } else {
+ cp = predicate;
+ outcomePatterns.add(newGroup);
+ newGroup = new ArrayList<>();
+ newGroup.add(predicate);
+ }
+ }
+ outcomePatterns.add(newGroup);
+ logger.info("{} outcome patterns", outcomePatterns.size());
+ return outcomePatterns;
}
/**
@@ -129,48 +97,8 @@ public abstract class NaiveBayesModelWriter extends
AbstractModelWriter {
return outcomePatterns;
}
- /**
- * Computes outcome patterns via {@link ComparablePredicate[] predicates}.
- *
- * @return A {@link List} of {@link List<ComparablePredicate>} that represent
- * the outcomes patterns.
- */
- protected List<List<ComparablePredicate>>
computeOutcomePatterns(ComparablePredicate[] sorted) {
- ComparablePredicate cp = sorted[0];
- List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>();
- List<ComparablePredicate> newGroup = new ArrayList<>();
- for (ComparablePredicate predicate : sorted) {
- if (cp.compareTo(predicate) == 0) {
- newGroup.add(predicate);
- } else {
- cp = predicate;
- outcomePatterns.add(newGroup);
- newGroup = new ArrayList<>();
- newGroup.add(predicate);
- }
- }
- outcomePatterns.add(newGroup);
- logger.info("{} outcome patterns", outcomePatterns.size());
- return outcomePatterns;
- }
-
- /**
- * Writes the {@link AbstractModel perceptron model}, using the
- * {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link
#writeInt(int)}}
- * methods implemented by extending classes.
- *
- * <p>If you wish to create a {@link NaiveBayesModelWriter} which uses a
different
- * structure, it will be necessary to override the {@code #persist()} method
in
- * addition to implementing the {@code writeX(..)} methods.
- *
- * @throws IOException Thrown if IO errors occurred.
- */
@Override
public void persist() throws IOException {
-
- // the type of model (NaiveBayes)
- writeUTF("NaiveBayes");
-
// the mapping from outcomes to their integer indexes
writeInt(OUTCOME_LABELS.length);
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java
index 78ada10e..832e063f 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java
@@ -18,11 +18,11 @@
package opennlp.tools.ml.maxent.io;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import opennlp.tools.ml.AbstractMLModelWriter;
import opennlp.tools.ml.maxent.GISModel;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
@@ -35,15 +35,15 @@ import opennlp.tools.ml.model.Context;
* It provides the {@link #persist()} method which takes care of the structure
of a
* stored document, and requires an extending class to define precisely how
* the data should be stored.
+ *
+ * @see GISModel
+ * @see AbstractModelWriter
+ * @see AbstractMLModelWriter
*/
-public abstract class GISModelWriter extends AbstractModelWriter {
- protected Context[] PARAMS;
- protected String[] OUTCOME_LABELS;
- protected String[] PRED_LABELS;
+public abstract class GISModelWriter extends AbstractMLModelWriter {
/**
- * Initializes a {@link GISModelWriter} for a
- * {@link AbstractModel GIS model}.
+ * Initializes a {@link GISModelWriter} for a {@link AbstractModel GIS
model}.
*
* @param model The {@link AbstractModel GIS model} to be written.
*/
@@ -65,14 +65,45 @@ public abstract class GISModelWriter extends
AbstractModelWriter {
i++;
}
}
-
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ protected ComparablePredicate[] sortValues() {
+
+ ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length];
+
+ int numParams = 0;
+ for (int pid = 0; pid < PARAMS.length; pid++) {
+ int[] predkeys = PARAMS[pid].getOutcomes();
+ // Arrays.sort(predkeys);
+ int numActive = predkeys.length;
+ double[] activeParams = PARAMS[pid].getParameters();
+
+ numParams += numActive;
+ /*
+ * double[] activeParams = new double[numActive];
+ *
+ * int id = 0; for (int i=0; i < predkeys.length; i++) { int oid =
+ * predkeys[i]; activeOutcomes[id] = oid; activeParams[id] =
+ * PARAMS[pid].getParams(oid); id++; }
+ */
+ sortPreds[pid] = new ComparablePredicate(PRED_LABELS[pid],
+ predkeys, activeParams);
+ }
+
+ Arrays.sort(sortPreds);
+ return sortPreds;
+ }
+
/**
* Writes the {@link AbstractModel GIS model}, using the
* {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link
#writeInt(int)}}
* methods implemented by extending classes.
*
* <p>If you wish to create a {@link GISModelWriter} which uses a different
- * structure, it will be necessary to override the {@code #persist()} method
in
+ * structure, it will be necessary to override the {@link #persist()} method
in
* addition to implementing the {@code writeX(..)} methods.
*
* @throws IOException Thrown if IO errors occurred.
@@ -124,62 +155,4 @@ public abstract class GISModelWriter extends
AbstractModelWriter {
close();
}
-
- /**
- * Sorts and optimizes the model parameters.
- *
- * @return A {@link ComparablePredicate[]}.
- */
- protected ComparablePredicate[] sortValues() {
-
- ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length];
-
- int numParams = 0;
- for (int pid = 0; pid < PARAMS.length; pid++) {
- int[] predkeys = PARAMS[pid].getOutcomes();
- // Arrays.sort(predkeys);
- int numActive = predkeys.length;
- double[] activeParams = PARAMS[pid].getParameters();
-
- numParams += numActive;
- /*
- * double[] activeParams = new double[numActive];
- *
- * int id = 0; for (int i=0; i < predkeys.length; i++) { int oid =
- * predkeys[i]; activeOutcomes[id] = oid; activeParams[id] =
- * PARAMS[pid].getParams(oid); id++; }
- */
- sortPreds[pid] = new ComparablePredicate(PRED_LABELS[pid],
- predkeys, activeParams);
- }
-
- Arrays.sort(sortPreds);
- return sortPreds;
- }
-
- /**
- * Compresses outcome patterns.
- *
- * @return A {@link List} of {@link List<ComparablePredicate>} that represent
- * the remaining outcomes patterns.
- */
- protected List<List<ComparablePredicate>>
compressOutcomes(ComparablePredicate[] sorted) {
- List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>();
- if (sorted.length > 0) {
- ComparablePredicate cp = sorted[0];
- List<ComparablePredicate> newGroup = new ArrayList<>();
- for (ComparablePredicate comparablePredicate : sorted) {
- if (cp.compareTo(comparablePredicate) == 0) {
- newGroup.add(comparablePredicate);
- } else {
- cp = comparablePredicate;
- outcomePatterns.add(newGroup);
- newGroup = new ArrayList<>();
- newGroup.add(comparablePredicate);
- }
- }
- outcomePatterns.add(newGroup);
- }
- return outcomePatterns;
- }
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
index eb329029..a434ce4c 100644
---
a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
+++
b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java
@@ -18,14 +18,10 @@
package opennlp.tools.ml.naivebayes;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.Map;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
+import opennlp.tools.ml.AbstractMLModelWriter;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
import opennlp.tools.ml.model.ComparablePredicate;
@@ -40,18 +36,18 @@ import opennlp.tools.ml.model.Context;
*
* @see NaiveBayesModel
* @see AbstractModelWriter
+ * @see AbstractMLModelWriter
*/
-public abstract class NaiveBayesModelWriter extends AbstractModelWriter {
-
- private static final Logger logger =
LoggerFactory.getLogger(NaiveBayesModelWriter.class);
-
- protected Context[] PARAMS;
- protected String[] OUTCOME_LABELS;
- protected String[] PRED_LABELS;
- int numOutcomes;
+public abstract class NaiveBayesModelWriter extends AbstractMLModelWriter {
+ /**
+ * Initializes a {@link NaiveBayesModelWriter} for a
+ * {@link AbstractModel NaiveBayes model}.
+ *
+ * @param model The {@link AbstractModel NaiveBayes model} to be written.
+ */
public NaiveBayesModelWriter(AbstractModel model) {
-
+ super();
Object[] data = model.getDataStructures();
this.numOutcomes = model.getNumOutcomes();
PARAMS = (Context[]) data[0];
@@ -72,10 +68,9 @@ public abstract class NaiveBayesModelWriter extends
AbstractModelWriter {
}
/**
- * Sorts and optimizes the model parameters.
- *
- * @return A {@link ComparablePredicate[]}.
+ * {@inheritDoc}
*/
+ @Override
protected ComparablePredicate[] sortValues() {
ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length];
@@ -103,105 +98,21 @@ public abstract class NaiveBayesModelWriter extends
AbstractModelWriter {
return sortPreds;
}
- /**
- * Compresses outcome patterns.
- *
- * @return A {@link List} of {@link List<ComparablePredicate>} that represent
- * the remaining outcomes patterns.
- */
- protected List<List<ComparablePredicate>>
compressOutcomes(ComparablePredicate[] sorted) {
- List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>();
- if (sorted.length > 0) {
- ComparablePredicate cp = sorted[0];
- List<ComparablePredicate> newGroup = new ArrayList<>();
- for (ComparablePredicate comparablePredicate : sorted) {
- if (cp.compareTo(comparablePredicate) == 0) {
- newGroup.add(comparablePredicate);
- } else {
- cp = comparablePredicate;
- outcomePatterns.add(newGroup);
- newGroup = new ArrayList<>();
- newGroup.add(comparablePredicate);
- }
- }
- outcomePatterns.add(newGroup);
- }
- return outcomePatterns;
- }
-
- /**
- * Computes outcome patterns via {@link ComparablePredicate[] predicates}.
- *
- * @return A {@link List} of {@link List<ComparablePredicate>} that represent
- * the outcomes patterns.
- */
- protected List<List<ComparablePredicate>>
computeOutcomePatterns(ComparablePredicate[] sorted) {
- ComparablePredicate cp = sorted[0];
- List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>();
- List<ComparablePredicate> newGroup = new ArrayList<>();
- for (ComparablePredicate predicate : sorted) {
- if (cp.compareTo(predicate) == 0) {
- newGroup.add(predicate);
- } else {
- cp = predicate;
- outcomePatterns.add(newGroup);
- newGroup = new ArrayList<>();
- newGroup.add(predicate);
- }
- }
- outcomePatterns.add(newGroup);
- logger.info("{} outcome patterns", outcomePatterns.size());
- return outcomePatterns;
- }
-
/**
* Writes the {@link AbstractModel perceptron model}, using the
* {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link
#writeInt(int)}}
* methods implemented by extending classes.
*
* <p>If you wish to create a {@link NaiveBayesModelWriter} which uses a
different
- * structure, it will be necessary to override the {@code #persist()} method
in
+ * structure, it will be necessary to override the {@link #persist()} method
in
* addition to implementing the {@code writeX(..)} methods.
*
* @throws IOException Thrown if IO errors occurred.
*/
@Override
public void persist() throws IOException {
-
// the type of model (NaiveBayes)
writeUTF("NaiveBayes");
-
- // the mapping from outcomes to their integer indexes
- writeInt(OUTCOME_LABELS.length);
-
- for (String label : OUTCOME_LABELS) {
- writeUTF(label);
- }
-
- // the mapping from predicates to the outcomes they contributed to.
- // The sorting is done so that we actually can write this out more
- // compactly than as the entire list.
- ComparablePredicate[] sorted = sortValues();
- List<List<ComparablePredicate>> compressed =
computeOutcomePatterns(sorted);
-
- writeInt(compressed.size());
-
- for (List<ComparablePredicate> a : compressed) {
- writeUTF(a.size() + a.get(0).toString());
- }
-
- // the mapping from predicate names to their integer indexes
- writeInt(sorted.length);
-
- for (ComparablePredicate s : sorted) {
- writeUTF(s.name);
- }
-
- // write out the parameters
- for (ComparablePredicate comparablePredicate : sorted)
- for (int j = 0; j < comparablePredicate.params.length; j++)
- writeDouble(comparablePredicate.params[j]);
-
- close();
+ super.persist();
}
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java
b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java
index 87950c01..be233e46 100644
---
a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java
+++
b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java
@@ -18,14 +18,13 @@
package opennlp.tools.ml.perceptron;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import opennlp.tools.ml.AbstractMLModelWriter;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.AbstractModelWriter;
import opennlp.tools.ml.model.ComparablePredicate;
@@ -40,14 +39,11 @@ import opennlp.tools.ml.model.Context;
*
* @see PerceptronModel
* @see AbstractModelWriter
+ * @see AbstractMLModelWriter
*/
-public abstract class PerceptronModelWriter extends AbstractModelWriter {
+public abstract class PerceptronModelWriter extends AbstractMLModelWriter {
private static final Logger logger =
LoggerFactory.getLogger(PerceptronModelWriter.class);
- protected Context[] PARAMS;
- protected String[] OUTCOME_LABELS;
- protected String[] PRED_LABELS;
- private final int numOutcomes;
/**
* Initializes a {@link PerceptronModelWriter} for a
@@ -56,9 +52,9 @@ public abstract class PerceptronModelWriter extends
AbstractModelWriter {
* @param model The {@link AbstractModel perceptron model} to be written.
*/
public PerceptronModelWriter(AbstractModel model) {
-
+ super();
Object[] data = model.getDataStructures();
- this.numOutcomes = model.getNumOutcomes();
+ numOutcomes = model.getNumOutcomes();
PARAMS = (Context[]) data[0];
@SuppressWarnings("unchecked")
@@ -77,11 +73,9 @@ public abstract class PerceptronModelWriter extends
AbstractModelWriter {
}
/**
- * Sorts and optimizes the model parameters. Thereby, parameters with
- * {@code 0} weight and predicates with no parameters are removed.
- *
- * @return A {@link ComparablePredicate[]}.
+ * {@inheritDoc}
*/
+ @Override
protected ComparablePredicate[] sortValues() {
ComparablePredicate[] sortPreds;
ComparablePredicate[] tmpPreds = new ComparablePredicate[PARAMS.length];
@@ -120,79 +114,21 @@ public abstract class PerceptronModelWriter extends
AbstractModelWriter {
return sortPreds;
}
- /**
- * Computes outcome patterns via {@link ComparablePredicate[] predicates}.
- *
- * @return A {@link List} of {@link List<ComparablePredicate>} that represent
- * the outcomes patterns.
- */
- protected List<List<ComparablePredicate>>
computeOutcomePatterns(ComparablePredicate[] sorted) {
- ComparablePredicate cp = sorted[0];
- List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>();
- List<ComparablePredicate> newGroup = new ArrayList<>();
- for (ComparablePredicate predicate : sorted) {
- if (cp.compareTo(predicate) == 0) {
- newGroup.add(predicate);
- } else {
- cp = predicate;
- outcomePatterns.add(newGroup);
- newGroup = new ArrayList<>();
- newGroup.add(predicate);
- }
- }
- outcomePatterns.add(newGroup);
- logger.info("{} outcome patterns", outcomePatterns.size());
- return outcomePatterns;
- }
-
/**
* Writes the {@link AbstractModel perceptron model}, using the
* {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link
#writeInt(int)}}
* methods implemented by extending classes.
*
* <p>If you wish to create a {@link PerceptronModelWriter} which uses a
different
- * structure, it will be necessary to override the {@code #persist()} method
in
+ * structure, it will be necessary to override the {@link #persist()} method
in
* addition to implementing the {@code writeX(..)} methods.
*
* @throws IOException Thrown if IO errors occurred.
*/
@Override
public void persist() throws IOException {
-
// the type of model (Perceptron)
writeUTF("Perceptron");
-
- // the mapping from outcomes to their integer indexes
- writeInt(OUTCOME_LABELS.length);
-
- for (String label : OUTCOME_LABELS) {
- writeUTF(label);
- }
-
- // the mapping from predicates to the outcomes they contributed to.
- // The sorting is done so that we actually can write this out more
- // compactly than as the entire list.
- ComparablePredicate[] sorted = sortValues();
- List<List<ComparablePredicate>> compressed =
computeOutcomePatterns(sorted);
-
- writeInt(compressed.size());
-
- for (List<ComparablePredicate> a : compressed) {
- writeUTF(a.size() + a.get(0).toString());
- }
-
- // the mapping from predicate names to their integer indexes
- writeInt(sorted.length);
-
- for (ComparablePredicate s : sorted) {
- writeUTF(s.name);
- }
-
- // write out the parameters
- for (ComparablePredicate comparablePredicate : sorted)
- for (int j = 0; j < comparablePredicate.params.length; j++)
- writeDouble(comparablePredicate.params[j]);
-
- close();
+ super.persist();
}
}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java
index 4b5b5573..bd4de105 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java
@@ -23,36 +23,39 @@ import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.junit.jupiter.api.Assertions;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
-import opennlp.tools.ml.perceptron.PerceptronPrepAttachTest;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.ObjectStreamUtils;
public class PrepAttachDataUtil {
- private static List<Event> readPpaFile(String filename) throws IOException {
+ /* Caches ppa files as List<Event> via their name (key) */
+ private static final Map<String, List<Event>> PPA_FILE_EVENTS = new
HashMap<>();
- List<Event> events = new ArrayList<>();
-
- try (InputStream in =
PerceptronPrepAttachTest.class.getResourceAsStream("/data/ppa/" +
- filename)) {
- BufferedReader reader = new BufferedReader(new InputStreamReader(in,
StandardCharsets.UTF_8));
- String line;
- while ((line = reader.readLine()) != null) {
- String[] items = line.split("\\s+");
- String label = items[5];
- String[] context = {"verb=" + items[1], "noun=" + items[2],
- "prep=" + items[3], "prep_obj=" + items[4]};
- events.add(new Event(label, context));
+ private static List<Event> readPpaFile(String filename) throws IOException {
+ if (!PPA_FILE_EVENTS.containsKey(filename)) {
+ List<Event> events = new ArrayList<>();
+ try (InputStream in =
PrepAttachDataUtil.class.getResourceAsStream("/data/ppa/" + filename);
+ BufferedReader reader = new BufferedReader(new
InputStreamReader(in, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ String[] items = line.split("\\s+");
+ String label = items[5];
+ String[] context = {"verb=" + items[1], "noun=" + items[2],
+ "prep=" + items[3], "prep_obj=" + items[4]};
+ events.add(new Event(label, context));
+ }
+ PPA_FILE_EVENTS.put(filename, events);
}
}
-
- return events;
+ return PPA_FILE_EVENTS.get(filename);
}
public static ObjectStream<Event> createTrainingStream() throws IOException {
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/AbstractNaiveBayesTest.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/AbstractNaiveBayesTest.java
new file mode 100644
index 00000000..ce275859
--- /dev/null
+++
b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/AbstractNaiveBayesTest.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.naivebayes;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import opennlp.tools.ml.model.Event;
+import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.ObjectStreamUtils;
+
+public class AbstractNaiveBayesTest {
+
+ protected ObjectStream<Event> createTrainingStream() {
+ List<Event> trainingEvents = new ArrayList<>();
+
+ String label1 = "politics";
+ String[] context1 = {"bow=the", "bow=united", "bow=nations"};
+ trainingEvents.add(new Event(label1, context1));
+
+ String label2 = "politics";
+ String[] context2 = {"bow=the", "bow=united", "bow=states", "bow=and"};
+ trainingEvents.add(new Event(label2, context2));
+
+ String label3 = "sports";
+ String[] context3 = {"bow=manchester", "bow=united"};
+ trainingEvents.add(new Event(label3, context3));
+
+ String label4 = "sports";
+ String[] context4 = {"bow=manchester", "bow=and", "bow=barca"};
+ trainingEvents.add(new Event(label4, context4));
+
+ return ObjectStreamUtils.createObjectStream(trainingEvents);
+ }
+}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java
b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java
index a22861ab..0aacb370 100644
---
a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java
+++
b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java
@@ -18,13 +18,14 @@
package opennlp.tools.ml.naivebayes;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.HashMap;
-import java.util.List;
+import java.util.stream.Stream;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
import opennlp.tools.ml.AbstractTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
@@ -32,87 +33,45 @@ import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.TwoPassDataIndexer;
-import opennlp.tools.util.ObjectStream;
-import opennlp.tools.util.ObjectStreamUtils;
import opennlp.tools.util.TrainingParameters;
/**
* Test for naive bayes classification correctness without smoothing
*/
-public class NaiveBayesCorrectnessTest {
+public class NaiveBayesCorrectnessTest extends AbstractNaiveBayesTest {
private DataIndexer testDataIndexer;
@BeforeEach
- void initIndexer() {
+ void initIndexer() throws IOException {
TrainingParameters trainingParameters = new TrainingParameters();
trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1);
trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
testDataIndexer = new TwoPassDataIndexer();
testDataIndexer.init(trainingParameters, new HashMap<>());
- }
-
- @Test
- void testNaiveBayes1() throws IOException {
-
testDataIndexer.index(createTrainingStream());
- NaiveBayesModel model =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- String label = "politics";
- String[] context = {"bow=united", "bow=nations"};
- Event event = new Event(label, context);
-
- // testModel(model, event, 1.0); // Expected value without smoothing
- testModel(model, event, 0.9681650180264167); // Expected value with
smoothing
-
}
- @Test
- void testNaiveBayes2() throws IOException {
-
- testDataIndexer.index(createTrainingStream());
- NaiveBayesModel model =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- String label = "sports";
- String[] context = {"bow=manchester", "bow=united"};
+ @ParameterizedTest
+ @MethodSource("provideLabelsWithContextAndProb")
+ void testNaiveBayes(String label, String[] context, double expectedProb) {
+ NaiveBayesModel model = (NaiveBayesModel) new
NaiveBayesTrainer().trainModel(testDataIndexer);
Event event = new Event(label, context);
-
- // testModel(model, event, 1.0); // Expected value without smoothing
- testModel(model, event, 0.9658833555831029); // Expected value with
smoothing
-
+
+ testModel(model, event, expectedProb); // Expected value with smoothing
}
- @Test
- void testNaiveBayes3() throws IOException {
-
- testDataIndexer.index(createTrainingStream());
- NaiveBayesModel model =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- String label = "politics";
- String[] context = {"bow=united"};
- Event event = new Event(label, context);
-
- //testModel(model, event, 2.0/3.0); // Expected value without smoothing
- testModel(model, event, 0.6655036407766989); // Expected value with
smoothing
-
- }
-
- @Test
- void testNaiveBayes4() throws IOException {
-
- testDataIndexer.index(createTrainingStream());
- NaiveBayesModel model =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- String label = "politics";
- String[] context = {};
- Event event = new Event(label, context);
-
- testModel(model, event, 7.0 / 12.0);
-
+ /*
+ * Produces a stream of <label|context|expected probability> triples for parameterized unit tests.
+ */
+ private static Stream<Arguments> provideLabelsWithContextAndProb() {
+ return Stream.of(
+ // Example 1:
+ Arguments.of("politics" , new String[] {"bow=united",
"bow=nations"}, 0.9681650180264167),
+ Arguments.of("sports", new String[] {"bow=manchester",
"bow=united"}, 0.9658833555831029),
+ Arguments.of("politics", new String[] {"bow=united"},
0.6655036407766989),
+ Arguments.of("politics", new String[] {}, 7.0 / 12.0)
+ );
}
private void testModel(MaxentModel model, Event event, double higher_probability) {
@@ -134,26 +93,4 @@ public class NaiveBayesCorrectnessTest {
}
}
- public static ObjectStream<Event> createTrainingStream() {
- List<Event> trainingEvents = new ArrayList<>();
-
- String label1 = "politics";
- String[] context1 = {"bow=the", "bow=united", "bow=nations"};
- trainingEvents.add(new Event(label1, context1));
-
- String label2 = "politics";
- String[] context2 = {"bow=the", "bow=united", "bow=states", "bow=and"};
- trainingEvents.add(new Event(label2, context2));
-
- String label3 = "sports";
- String[] context3 = {"bow=manchester", "bow=united"};
- trainingEvents.add(new Event(label3, context3));
-
- String label4 = "sports";
- String[] context4 = {"bow=manchester", "bow=and", "bow=barca"};
- trainingEvents.add(new Event(label4, context4));
-
- return ObjectStreamUtils.createObjectStream(trainingEvents);
- }
-
}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java
index 64de26ab..6c135c9d 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java
@@ -18,6 +18,7 @@
package opennlp.tools.ml.naivebayes;
import java.io.File;
+import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
@@ -34,24 +35,24 @@ import opennlp.tools.ml.model.TwoPassDataIndexer;
import opennlp.tools.util.TrainingParameters;
/**
- * Tests for persisting and reading naive bayes models
+ * Tests for persisting and reading naive bayes models.
*/
-public class NaiveBayesModelReadWriteTest {
+public class NaiveBayesModelReadWriteTest extends AbstractNaiveBayesTest {
private DataIndexer testDataIndexer;
@BeforeEach
- void initIndexer() {
+ void initIndexer() throws IOException {
TrainingParameters trainingParameters = new TrainingParameters();
trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1);
trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
testDataIndexer = new TwoPassDataIndexer();
testDataIndexer.init(trainingParameters, new HashMap<>());
+ testDataIndexer.index(createTrainingStream());
}
@Test
- void testBinaryModelPersistence() throws Exception {
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
+ void testBinaryModelPersistence() throws IOException {
NaiveBayesModel model = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
Path tempFile = Files.createTempFile("bnb-", ".bin");
File file = tempFile.toFile();
@@ -69,7 +70,6 @@ public class NaiveBayesModelReadWriteTest {
@Test
void testTextModelPersistence() throws Exception {
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
NaiveBayesModel model = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
Path tempFile = Files.createTempFile("ptnb-", ".txt");
File file = tempFile.toFile();
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java
index 904e7549..06572abb 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java
@@ -30,8 +30,10 @@ import opennlp.tools.ml.PrepAttachDataUtil;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.AbstractDataIndexer;
import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.TwoPassDataIndexer;
+import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
/**
@@ -39,20 +41,23 @@ import opennlp.tools.util.TrainingParameters;
*/
public class NaiveBayesPrepAttachTest {
- private DataIndexer testDataIndexer;
+ private ObjectStream<Event> trainingStream;
@BeforeEach
- void initIndexer() {
- TrainingParameters trainingParameters = new TrainingParameters();
- trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1);
- trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
- testDataIndexer = new TwoPassDataIndexer();
- testDataIndexer.init(trainingParameters, new HashMap<>());
+ void initIndexer() throws IOException {
+ trainingStream = PrepAttachDataUtil.createTrainingStream();
+ Assertions.assertNotNull(trainingStream);
}
@Test
void testNaiveBayesOnPrepAttachData() throws IOException {
- testDataIndexer.index(PrepAttachDataUtil.createTrainingStream());
+ TrainingParameters trainingParameters = new TrainingParameters();
+ trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1);
+ trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
+ DataIndexer testDataIndexer = new TwoPassDataIndexer();
+ testDataIndexer.init(trainingParameters, new HashMap<>());
+ testDataIndexer.index(trainingStream);
+
MaxentModel model = new NaiveBayesTrainer().trainModel(testDataIndexer);
Assertions.assertTrue(model instanceof NaiveBayesModel);
PrepAttachDataUtil.testModel(model, 0.7897994553107205);
@@ -65,7 +70,7 @@ public class NaiveBayesPrepAttachTest {
trainParams.put(AbstractTrainer.CUTOFF_PARAM, 1);
EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null);
- MaxentModel model = trainer.train(PrepAttachDataUtil.createTrainingStream());
+ MaxentModel model = trainer.train(trainingStream);
Assertions.assertTrue(model instanceof NaiveBayesModel);
PrepAttachDataUtil.testModel(model, 0.7897994553107205);
}
@@ -77,7 +82,7 @@ public class NaiveBayesPrepAttachTest {
trainParams.put(AbstractTrainer.CUTOFF_PARAM, 5);
EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null);
- MaxentModel model = trainer.train(PrepAttachDataUtil.createTrainingStream());
+ MaxentModel model = trainer.train(trainingStream);
Assertions.assertTrue(model instanceof NaiveBayesModel);
PrepAttachDataUtil.testModel(model, 0.7945035899975241);
}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java
index 31f37426..e0abd823 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java
@@ -26,10 +26,14 @@ import java.io.StringWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
+import java.util.stream.Stream;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
import opennlp.tools.ml.AbstractTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
@@ -39,101 +43,51 @@ import opennlp.tools.ml.model.TwoPassDataIndexer;
import opennlp.tools.util.TrainingParameters;
/**
- * Test for naive bayes classification correctness without smoothing
+ * Test for naive bayes classification correctness without smoothing.
*/
-public class NaiveBayesSerializedCorrectnessTest {
+public class NaiveBayesSerializedCorrectnessTest extends AbstractNaiveBayesTest {
private DataIndexer testDataIndexer;
@BeforeEach
- void initIndexer() {
+ void initIndexer() throws IOException {
TrainingParameters trainingParameters = new TrainingParameters();
trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1);
trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
testDataIndexer = new TwoPassDataIndexer();
testDataIndexer.init(trainingParameters, new HashMap<>());
+ testDataIndexer.index(createTrainingStream());
}
- @Test
- void testNaiveBayes1() throws IOException {
-
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
- NaiveBayesModel model1 =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- NaiveBayesModel model2 = persistedModel(model1);
-
- String label = "politics";
- String[] context = {"bow=united", "bow=nations"};
- Event event = new Event(label, context);
-
- testModelOutcome(model1, model2, event);
-
- }
-
- @Test
- void testNaiveBayes2() throws IOException {
-
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
- NaiveBayesModel model1 =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- NaiveBayesModel model2 = persistedModel(model1);
-
- String label = "sports";
- String[] context = {"bow=manchester", "bow=united"};
- Event event = new Event(label, context);
-
- testModelOutcome(model1, model2, event);
-
- }
-
- @Test
- void testNaiveBayes3() throws IOException {
-
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
- NaiveBayesModel model1 =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
+ @ParameterizedTest
+ @MethodSource("provideLabelsWithContext")
+ void testNaiveBayes(String label, String[] context) throws IOException {
+ NaiveBayesModel model1 = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
NaiveBayesModel model2 = persistedModel(model1);
-
- String label = "politics";
- String[] context = {"bow=united"};
Event event = new Event(label, context);
-
testModelOutcome(model1, model2, event);
-
}
- @Test
- void testNaiveBayes4() throws IOException {
-
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
- NaiveBayesModel model1 =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
- NaiveBayesModel model2 = persistedModel(model1);
-
- String label = "politics";
- String[] context = {};
- Event event = new Event(label, context);
-
- testModelOutcome(model1, model2, event);
-
+ /*
+ * Produces a stream of <label|context> pairs for parameterized unit tests.
+ */
+ private static Stream<Arguments> provideLabelsWithContext() {
+ return Stream.of(
+ // Example 1:
+ Arguments.of("politics" , new String[] {"bow=united", "bow=nations"}),
+ Arguments.of("sports", new String[] {"bow=manchester", "bow=united"}),
+ Arguments.of("politics", new String[] {"bow=united"}),
+ Arguments.of("politics", new String[] {})
+ );
}
-
@Test
void testPlainTextModel() throws IOException {
- testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream());
- NaiveBayesModel model1 =
- (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
-
+ NaiveBayesModel model1 = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer);
StringWriter sw1 = new StringWriter();
- NaiveBayesModelWriter modelWriter =
- new PlainTextNaiveBayesModelWriter(model1, new BufferedWriter(sw1));
+ NaiveBayesModelWriter modelWriter = new PlainTextNaiveBayesModelWriter(model1, new BufferedWriter(sw1));
modelWriter.persist();
NaiveBayesModelReader reader =
@@ -147,10 +101,9 @@ public class NaiveBayesSerializedCorrectnessTest {
modelWriter.persist();
Assertions.assertEquals(sw1.toString(), sw2.toString());
-
}
- protected static NaiveBayesModel persistedModel(NaiveBayesModel model) throws IOException {
+ private static NaiveBayesModel persistedModel(NaiveBayesModel model) throws IOException {
Path tempFilePath = Files.createTempFile("ptnb-", ".bin");
File file = tempFilePath.toFile();
try {
@@ -164,17 +117,15 @@ public class NaiveBayesSerializedCorrectnessTest {
}
}
- protected static void testModelOutcome(NaiveBayesModel model1, NaiveBayesModel model2, Event event) {
+ private static void testModelOutcome(NaiveBayesModel model1, NaiveBayesModel model2, Event event) {
String[] labels1 = extractLabels(model1);
String[] labels2 = extractLabels(model2);
Assertions.assertArrayEquals(labels1, labels2);
-
double[] outcomes1 = model1.eval(event.getContext());
double[] outcomes2 = model2.eval(event.getContext());
Assertions.assertArrayEquals(outcomes1, outcomes2, 0.000000000001);
-
}
private static String[] extractLabels(NaiveBayesModel model) {
diff --git a/pom.xml b/pom.xml
index 3c95fa74..828d06a5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -326,7 +326,7 @@
<configuration>
<failOnUnsupportedJava>false</failOnUnsupportedJava>
<bundledSignatures>
- <bundledSignature>jdk-deprecated</bundledSignature>
+ <bundledSignature>jdk-deprecated-${java.version}</bundledSignature>
<bundledSignature>jdk-non-portable</bundledSignature>
</bundledSignatures>
</configuration>