This is an automated email from the ASF dual-hosted git repository. mawiesne pushed a commit to branch OPENNLP-1495_Reduce_code_duplication_in_opennlp.tools.ml_package in repository https://gitbox.apache.org/repos/asf/opennlp.git
commit 684ca75e196a161398ca9dded879c6a94ea34757 Author: Martin Wiesner <[email protected]> AuthorDate: Fri May 19 19:32:13 2023 +0200 OPENNLP-1495 Reduce code duplication in opennlp.tools.ml package - reduces code duplication in PerceptronModelWriter, GISModelWriter, and NaiveBayesModelWriter by introducing AbstractMLModelWriter - simplifies related, existing test classes by using ParameterizedTest now - optimizes PrepAttachDataUtil to cache already read ppa files avoiding IO efforts during execution of the whole OpenNLP tools test suite - reduces test execution times by ~5-10 % - adjusts 'forbiddenapis' plugin's bundledSignature to more precisely cover 'jdk-deprecated-${java.version}' (for now: 11 !) which was by default 1.7 (outdated) --- ...ModelWriter.java => AbstractMLModelWriter.java} | 134 +++++---------------- .../opennlp/tools/ml/maxent/io/GISModelWriter.java | 107 ++++++---------- .../tools/ml/naivebayes/NaiveBayesModelWriter.java | 117 +++--------------- .../tools/ml/perceptron/PerceptronModelWriter.java | 82 ++----------- .../java/opennlp/tools/ml/PrepAttachDataUtil.java | 35 +++--- .../ml/naivebayes/AbstractNaiveBayesTest.java | 50 ++++++++ .../ml/naivebayes/NaiveBayesCorrectnessTest.java | 109 ++++------------- .../naivebayes/NaiveBayesModelReadWriteTest.java | 12 +- .../ml/naivebayes/NaiveBayesPrepAttachTest.java | 25 ++-- .../NaiveBayesSerializedCorrectnessTest.java | 103 +++++----------- pom.xml | 2 +- 11 files changed, 235 insertions(+), 541 deletions(-) diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractMLModelWriter.java similarity index 60% copy from opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java copy to opennlp-tools/src/main/java/opennlp/tools/ml/AbstractMLModelWriter.java index eb329029..ba11a370 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractMLModelWriter.java @@ -15,92 +15,60 @@ * limitations under the License. */ -package opennlp.tools.ml.naivebayes; +package opennlp.tools.ml; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.AbstractModelWriter; import opennlp.tools.ml.model.ComparablePredicate; import opennlp.tools.ml.model.Context; -/** - * The base class for {@link NaiveBayesModel} writers. - * <p> - * It provides the {@link #persist()} method which takes care of the structure - * of a stored document, and requires an extending class to define precisely - * how the data should be stored. - * - * @see NaiveBayesModel - * @see AbstractModelWriter - */ -public abstract class NaiveBayesModelWriter extends AbstractModelWriter { +public abstract class AbstractMLModelWriter extends AbstractModelWriter { - private static final Logger logger = LoggerFactory.getLogger(NaiveBayesModelWriter.class); + private static final Logger logger = LoggerFactory.getLogger(AbstractMLModelWriter.class); protected Context[] PARAMS; protected String[] OUTCOME_LABELS; protected String[] PRED_LABELS; - int numOutcomes; - - public NaiveBayesModelWriter(AbstractModel model) { - - Object[] data = model.getDataStructures(); - this.numOutcomes = model.getNumOutcomes(); - PARAMS = (Context[]) data[0]; - - @SuppressWarnings("unchecked") - Map<String, Context> pmap = (Map<String, Context>) data[1]; - - OUTCOME_LABELS = (String[]) data[2]; - PARAMS = new Context[pmap.size()]; - PRED_LABELS = new String[pmap.size()]; - - int i = 0; - for (Map.Entry<String, Context> pred : pmap.entrySet()) { - PRED_LABELS[i] = pred.getKey(); - PARAMS[i] = pred.getValue(); - i++; - } - } + + protected int numOutcomes; /** - * Sorts and optimizes the model parameters. + * Sorts and optimizes the model parameters. Thereby, parameters with + * {@code 0} weight and predicates with no parameters are removed. * * @return A {@link ComparablePredicate[]}. */ - protected ComparablePredicate[] sortValues() { - - ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length]; - - int numParams = 0; - for (int pid = 0; pid < PARAMS.length; pid++) { - int[] predkeys = PARAMS[pid].getOutcomes(); - // Arrays.sort(predkeys); - int numActive = predkeys.length; - double[] activeParams = PARAMS[pid].getParameters(); - - numParams += numActive; - /* - * double[] activeParams = new double[numActive]; - * - * int id = 0; for (int i=0; i < predkeys.length; i++) { int oid = - * predkeys[i]; activeOutcomes[id] = oid; activeParams[id] = - * PARAMS[pid].getParams(oid); id++; } - */ - sortPreds[pid] = new ComparablePredicate(PRED_LABELS[pid], - predkeys, activeParams); - } + protected abstract ComparablePredicate[] sortValues(); - Arrays.sort(sortPreds); - return sortPreds; + /** + * Computes outcome patterns via {@link ComparablePredicate[] predicates}. + * + * @return A {@link List} of {@link List<ComparablePredicate>} that represent + * the outcomes patterns. + */ + protected List<List<ComparablePredicate>> computeOutcomePatterns(ComparablePredicate[] sorted) { + ComparablePredicate cp = sorted[0]; + List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>(); + List<ComparablePredicate> newGroup = new ArrayList<>(); + for (ComparablePredicate predicate : sorted) { + if (cp.compareTo(predicate) == 0) { + newGroup.add(predicate); + } else { + cp = predicate; + outcomePatterns.add(newGroup); + newGroup = new ArrayList<>(); + newGroup.add(predicate); + } + } + outcomePatterns.add(newGroup); + logger.info("{} outcome patterns", outcomePatterns.size()); + return outcomePatterns; } /** @@ -129,48 +97,8 @@ public abstract class NaiveBayesModelWriter extends AbstractModelWriter { return outcomePatterns; } - /** - * Computes outcome patterns via {@link ComparablePredicate[] predicates}. - * - * @return A {@link List} of {@link List<ComparablePredicate>} that represent - * the outcomes patterns. - */ - protected List<List<ComparablePredicate>> computeOutcomePatterns(ComparablePredicate[] sorted) { - ComparablePredicate cp = sorted[0]; - List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>(); - List<ComparablePredicate> newGroup = new ArrayList<>(); - for (ComparablePredicate predicate : sorted) { - if (cp.compareTo(predicate) == 0) { - newGroup.add(predicate); - } else { - cp = predicate; - outcomePatterns.add(newGroup); - newGroup = new ArrayList<>(); - newGroup.add(predicate); - } - } - outcomePatterns.add(newGroup); - logger.info("{} outcome patterns", outcomePatterns.size()); - return outcomePatterns; - } - - /** - * Writes the {@link AbstractModel perceptron model}, using the - * {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link #writeInt(int)}} - * methods implemented by extending classes. - * - * <p>If you wish to create a {@link NaiveBayesModelWriter} which uses a different - * structure, it will be necessary to override the {@code #persist()} method in - * addition to implementing the {@code writeX(..)} methods. - * - * @throws IOException Thrown if IO errors occurred. - */ @Override public void persist() throws IOException { - - // the type of model (NaiveBayes) - writeUTF("NaiveBayes"); - // the mapping from outcomes to their integer indexes writeInt(OUTCOME_LABELS.length); diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java index 78ada10e..832e063f 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/io/GISModelWriter.java @@ -18,11 +18,11 @@ package opennlp.tools.ml.maxent.io; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import opennlp.tools.ml.AbstractMLModelWriter; import opennlp.tools.ml.maxent.GISModel; import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.AbstractModelWriter; @@ -35,15 +35,15 @@ import opennlp.tools.ml.model.Context; * It provides the {@link #persist()} method which takes care of the structure of a * stored document, and requires an extending class to define precisely how * the data should be stored. + * + * @see GISModel + * @see AbstractModelWriter + * @see AbstractMLModelWriter */ -public abstract class GISModelWriter extends AbstractModelWriter { - protected Context[] PARAMS; - protected String[] OUTCOME_LABELS; - protected String[] PRED_LABELS; +public abstract class GISModelWriter extends AbstractMLModelWriter { /** - * Initializes a {@link GISModelWriter} for a - * {@link AbstractModel GIS model}. + * Initializes a {@link GISModelWriter} for a {@link AbstractModel GIS model}. * * @param model The {@link AbstractModel GIS model} to be written. */ @@ -65,14 +65,45 @@ public abstract class GISModelWriter extends AbstractModelWriter { i++; } } - + + /** + * {@inheritDoc} + */ + @Override + protected ComparablePredicate[] sortValues() { + + ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length]; + + int numParams = 0; + for (int pid = 0; pid < PARAMS.length; pid++) { + int[] predkeys = PARAMS[pid].getOutcomes(); + // Arrays.sort(predkeys); + int numActive = predkeys.length; + double[] activeParams = PARAMS[pid].getParameters(); + + numParams += numActive; + /* + * double[] activeParams = new double[numActive]; + * + * int id = 0; for (int i=0; i < predkeys.length; i++) { int oid = + * predkeys[i]; activeOutcomes[id] = oid; activeParams[id] = + * PARAMS[pid].getParams(oid); id++; } + */ + sortPreds[pid] = new ComparablePredicate(PRED_LABELS[pid], + predkeys, activeParams); + } + + Arrays.sort(sortPreds); + return sortPreds; + } + /** * Writes the {@link AbstractModel GIS model}, using the * {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link #writeInt(int)}} * methods implemented by extending classes. * * <p>If you wish to create a {@link GISModelWriter} which uses a different - * structure, it will be necessary to override the {@code #persist()} method in + * structure, it will be necessary to override the {@link #persist()} method in * addition to implementing the {@code writeX(..)} methods. * * @throws IOException Thrown if IO errors occurred. @@ -124,62 +155,4 @@ public abstract class GISModelWriter extends AbstractModelWriter { close(); } - - /** - * Sorts and optimizes the model parameters. - * - * @return A {@link ComparablePredicate[]}. - */ - protected ComparablePredicate[] sortValues() { - - ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length]; - - int numParams = 0; - for (int pid = 0; pid < PARAMS.length; pid++) { - int[] predkeys = PARAMS[pid].getOutcomes(); - // Arrays.sort(predkeys); - int numActive = predkeys.length; - double[] activeParams = PARAMS[pid].getParameters(); - - numParams += numActive; - /* - * double[] activeParams = new double[numActive]; - * - * int id = 0; for (int i=0; i < predkeys.length; i++) { int oid = - * predkeys[i]; activeOutcomes[id] = oid; activeParams[id] = - * PARAMS[pid].getParams(oid); id++; } - */ - sortPreds[pid] = new ComparablePredicate(PRED_LABELS[pid], - predkeys, activeParams); - } - - Arrays.sort(sortPreds); - return sortPreds; - } - - /** - * Compresses outcome patterns. - * - * @return A {@link List} of {@link List<ComparablePredicate>} that represent - * the remaining outcomes patterns. - */ - protected List<List<ComparablePredicate>> compressOutcomes(ComparablePredicate[] sorted) { - List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>(); - if (sorted.length > 0) { - ComparablePredicate cp = sorted[0]; - List<ComparablePredicate> newGroup = new ArrayList<>(); - for (ComparablePredicate comparablePredicate : sorted) { - if (cp.compareTo(comparablePredicate) == 0) { - newGroup.add(comparablePredicate); - } else { - cp = comparablePredicate; - outcomePatterns.add(newGroup); - newGroup = new ArrayList<>(); - newGroup.add(comparablePredicate); - } - } - outcomePatterns.add(newGroup); - } - return outcomePatterns; - } } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java index eb329029..a434ce4c 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesModelWriter.java @@ -18,14 +18,10 @@ package opennlp.tools.ml.naivebayes; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import opennlp.tools.ml.AbstractMLModelWriter; import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.AbstractModelWriter; import opennlp.tools.ml.model.ComparablePredicate; @@ -40,18 +36,18 @@ import opennlp.tools.ml.model.Context; * * @see NaiveBayesModel * @see AbstractModelWriter + * @see AbstractMLModelWriter */ -public abstract class NaiveBayesModelWriter extends AbstractModelWriter { - - private static final Logger logger = LoggerFactory.getLogger(NaiveBayesModelWriter.class); - - protected Context[] PARAMS; - protected String[] OUTCOME_LABELS; - protected String[] PRED_LABELS; - int numOutcomes; +public abstract class NaiveBayesModelWriter extends AbstractMLModelWriter { + /** + * Initializes a {@link NaiveBayesModelWriter} for a + * {@link AbstractModel NaiveBayes model}. + * + * @param model The {@link AbstractModel NaiveBayes model} to be written. + */ public NaiveBayesModelWriter(AbstractModel model) { - + super(); Object[] data = model.getDataStructures(); this.numOutcomes = model.getNumOutcomes(); PARAMS = (Context[]) data[0]; @@ -72,10 +68,9 @@ public abstract class NaiveBayesModelWriter extends AbstractModelWriter { } /** - * Sorts and optimizes the model parameters. - * - * @return A {@link ComparablePredicate[]}. + * {@inheritDoc} */ + @Override protected ComparablePredicate[] sortValues() { ComparablePredicate[] sortPreds = new ComparablePredicate[PARAMS.length]; @@ -103,105 +98,21 @@ public abstract class NaiveBayesModelWriter extends AbstractModelWriter { return sortPreds; } - /** - * Compresses outcome patterns. - * - * @return A {@link List} of {@link List<ComparablePredicate>} that represent - * the remaining outcomes patterns. - */ - protected List<List<ComparablePredicate>> compressOutcomes(ComparablePredicate[] sorted) { - List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>(); - if (sorted.length > 0) { - ComparablePredicate cp = sorted[0]; - List<ComparablePredicate> newGroup = new ArrayList<>(); - for (ComparablePredicate comparablePredicate : sorted) { - if (cp.compareTo(comparablePredicate) == 0) { - newGroup.add(comparablePredicate); - } else { - cp = comparablePredicate; - outcomePatterns.add(newGroup); - newGroup = new ArrayList<>(); - newGroup.add(comparablePredicate); - } - } - outcomePatterns.add(newGroup); - } - return outcomePatterns; - } - - /** - * Computes outcome patterns via {@link ComparablePredicate[] predicates}. - * - * @return A {@link List} of {@link List<ComparablePredicate>} that represent - * the outcomes patterns. - */ - protected List<List<ComparablePredicate>> computeOutcomePatterns(ComparablePredicate[] sorted) { - ComparablePredicate cp = sorted[0]; - List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>(); - List<ComparablePredicate> newGroup = new ArrayList<>(); - for (ComparablePredicate predicate : sorted) { - if (cp.compareTo(predicate) == 0) { - newGroup.add(predicate); - } else { - cp = predicate; - outcomePatterns.add(newGroup); - newGroup = new ArrayList<>(); - newGroup.add(predicate); - } - } - outcomePatterns.add(newGroup); - logger.info("{} outcome patterns", outcomePatterns.size()); - return outcomePatterns; - } - /** * Writes the {@link AbstractModel perceptron model}, using the * {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link #writeInt(int)}} * methods implemented by extending classes. * * <p>If you wish to create a {@link NaiveBayesModelWriter} which uses a different - * structure, it will be necessary to override the {@code #persist()} method in + * structure, it will be necessary to override the {@link #persist()} method in * addition to implementing the {@code writeX(..)} methods. * * @throws IOException Thrown if IO errors occurred. */ @Override public void persist() throws IOException { - // the type of model (NaiveBayes) writeUTF("NaiveBayes"); - - // the mapping from outcomes to their integer indexes - writeInt(OUTCOME_LABELS.length); - - for (String label : OUTCOME_LABELS) { - writeUTF(label); - } - - // the mapping from predicates to the outcomes they contributed to. - // The sorting is done so that we actually can write this out more - // compactly than as the entire list. - ComparablePredicate[] sorted = sortValues(); - List<List<ComparablePredicate>> compressed = computeOutcomePatterns(sorted); - - writeInt(compressed.size()); - - for (List<ComparablePredicate> a : compressed) { - writeUTF(a.size() + a.get(0).toString()); - } - - // the mapping from predicate names to their integer indexes - writeInt(sorted.length); - - for (ComparablePredicate s : sorted) { - writeUTF(s.name); - } - - // write out the parameters - for (ComparablePredicate comparablePredicate : sorted) - for (int j = 0; j < comparablePredicate.params.length; j++) - writeDouble(comparablePredicate.params[j]); - - close(); + super.persist(); } } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java index 87950c01..be233e46 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronModelWriter.java @@ -18,14 +18,13 @@ package opennlp.tools.ml.perceptron; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import opennlp.tools.ml.AbstractMLModelWriter; import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.AbstractModelWriter; import opennlp.tools.ml.model.ComparablePredicate; @@ -40,14 +39,11 @@ import opennlp.tools.ml.model.Context; * * @see PerceptronModel * @see AbstractModelWriter + * @see AbstractMLModelWriter */ -public abstract class PerceptronModelWriter extends AbstractModelWriter { +public abstract class PerceptronModelWriter extends AbstractMLModelWriter { private static final Logger logger = LoggerFactory.getLogger(PerceptronModelWriter.class); - protected Context[] PARAMS; - protected String[] OUTCOME_LABELS; - protected String[] PRED_LABELS; - private final int numOutcomes; /** * Initializes a {@link PerceptronModelWriter} for a @@ -56,9 +52,9 @@ public abstract class PerceptronModelWriter extends AbstractModelWriter { * @param model The {@link AbstractModel perceptron model} to be written. */ public PerceptronModelWriter(AbstractModel model) { - + super(); Object[] data = model.getDataStructures(); - this.numOutcomes = model.getNumOutcomes(); + numOutcomes = model.getNumOutcomes(); PARAMS = (Context[]) data[0]; @SuppressWarnings("unchecked") @@ -77,11 +73,9 @@ public abstract class PerceptronModelWriter extends AbstractModelWriter { } /** - * Sorts and optimizes the model parameters. Thereby, parameters with - * {@code 0} weight and predicates with no parameters are removed. - * - * @return A {@link ComparablePredicate[]}. + * {@inheritDoc} */ + @Override protected ComparablePredicate[] sortValues() { ComparablePredicate[] sortPreds; ComparablePredicate[] tmpPreds = new ComparablePredicate[PARAMS.length]; @@ -120,79 +114,21 @@ public abstract class PerceptronModelWriter extends AbstractModelWriter { return sortPreds; } - /** - * Computes outcome patterns via {@link ComparablePredicate[] predicates}. - * - * @return A {@link List} of {@link List<ComparablePredicate>} that represent - * the outcomes patterns. - */ - protected List<List<ComparablePredicate>> computeOutcomePatterns(ComparablePredicate[] sorted) { - ComparablePredicate cp = sorted[0]; - List<List<ComparablePredicate>> outcomePatterns = new ArrayList<>(); - List<ComparablePredicate> newGroup = new ArrayList<>(); - for (ComparablePredicate predicate : sorted) { - if (cp.compareTo(predicate) == 0) { - newGroup.add(predicate); - } else { - cp = predicate; - outcomePatterns.add(newGroup); - newGroup = new ArrayList<>(); - newGroup.add(predicate); - } - } - outcomePatterns.add(newGroup); - logger.info("{} outcome patterns", outcomePatterns.size()); - return outcomePatterns; - } - /** * Writes the {@link AbstractModel perceptron model}, using the * {@link #writeUTF(String)}, {@link #writeDouble(double)}, or {@link #writeInt(int)}} * methods implemented by extending classes. * * <p>If you wish to create a {@link PerceptronModelWriter} which uses a different - * structure, it will be necessary to override the {@code #persist()} method in + * structure, it will be necessary to override the {@link #persist()} method in * addition to implementing the {@code writeX(..)} methods. * * @throws IOException Thrown if IO errors occurred. */ @Override public void persist() throws IOException { - // the type of model (Perceptron) writeUTF("Perceptron"); - - // the mapping from outcomes to their integer indexes - writeInt(OUTCOME_LABELS.length); - - for (String label : OUTCOME_LABELS) { - writeUTF(label); - } - - // the mapping from predicates to the outcomes they contributed to. - // The sorting is done so that we actually can write this out more - // compactly than as the entire list. - ComparablePredicate[] sorted = sortValues(); - List<List<ComparablePredicate>> compressed = computeOutcomePatterns(sorted); - - writeInt(compressed.size()); - - for (List<ComparablePredicate> a : compressed) { - writeUTF(a.size() + a.get(0).toString()); - } - - // the mapping from predicate names to their integer indexes - writeInt(sorted.length); - - for (ComparablePredicate s : sorted) { - writeUTF(s.name); - } - - // write out the parameters - for (ComparablePredicate comparablePredicate : sorted) - for (int j = 0; j < comparablePredicate.params.length; j++) - writeDouble(comparablePredicate.params[j]); - - close(); + super.persist(); } } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java b/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java index 4b5b5573..bd4de105 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/PrepAttachDataUtil.java @@ -23,36 +23,39 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Assertions; import opennlp.tools.ml.model.Event; import opennlp.tools.ml.model.MaxentModel; -import opennlp.tools.ml.perceptron.PerceptronPrepAttachTest; import opennlp.tools.util.ObjectStream; import opennlp.tools.util.ObjectStreamUtils; public class PrepAttachDataUtil { - private static List<Event> readPpaFile(String filename) throws IOException { + /* Caches ppa files as List<Event> via their name (key) */ + private static final Map<String, List<Event>> PPA_FILE_EVENTS = new HashMap<>(); - List<Event> events = new ArrayList<>(); - - try (InputStream in = PerceptronPrepAttachTest.class.getResourceAsStream("/data/ppa/" + - filename)) { - BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); - String line; - while ((line = reader.readLine()) != null) { - String[] items = line.split("\\s+"); - String label = items[5]; - String[] context = {"verb=" + items[1], "noun=" + items[2], - "prep=" + items[3], "prep_obj=" + items[4]}; - events.add(new Event(label, context)); + private static List<Event> readPpaFile(String filename) throws IOException { + if (!PPA_FILE_EVENTS.containsKey(filename)) { + List<Event> events = new ArrayList<>(); + try (InputStream in = PrepAttachDataUtil.class.getResourceAsStream("/data/ppa/" + filename); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + String[] items = line.split("\\s+"); + String label = items[5]; + String[] context = {"verb=" + items[1], "noun=" + items[2], + "prep=" + items[3], "prep_obj=" + items[4]}; + events.add(new Event(label, context)); + } + PPA_FILE_EVENTS.put(filename, events); } } - - return events; + return PPA_FILE_EVENTS.get(filename); } public static ObjectStream<Event> createTrainingStream() throws IOException { diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/AbstractNaiveBayesTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/AbstractNaiveBayesTest.java new file mode 100644 index 00000000..ce275859 --- /dev/null +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/AbstractNaiveBayesTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.ml.naivebayes; + +import java.util.ArrayList; +import java.util.List; + +import opennlp.tools.ml.model.Event; +import opennlp.tools.util.ObjectStream; +import opennlp.tools.util.ObjectStreamUtils; + +public class AbstractNaiveBayesTest { + + protected ObjectStream<Event> createTrainingStream() { + List<Event> trainingEvents = new ArrayList<>(); + + String label1 = "politics"; + String[] context1 = {"bow=the", "bow=united", "bow=nations"}; + trainingEvents.add(new Event(label1, context1)); + + String label2 = "politics"; + String[] context2 = {"bow=the", "bow=united", "bow=states", "bow=and"}; + trainingEvents.add(new Event(label2, context2)); + + String label3 = "sports"; + String[] context3 = {"bow=manchester", "bow=united"}; + trainingEvents.add(new Event(label3, context3)); + + String label4 = "sports"; + String[] context4 = {"bow=manchester", "bow=and", "bow=barca"}; + trainingEvents.add(new Event(label4, context4)); + + return ObjectStreamUtils.createObjectStream(trainingEvents); + } +} diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java index a22861ab..0aacb370 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.java @@ -18,13 +18,14 @@ package opennlp.tools.ml.naivebayes; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; +import java.util.stream.Stream; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import opennlp.tools.ml.AbstractTrainer; import opennlp.tools.ml.model.AbstractDataIndexer; @@ -32,87 +33,45 @@ import opennlp.tools.ml.model.DataIndexer; import opennlp.tools.ml.model.Event; import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.ml.model.TwoPassDataIndexer; -import opennlp.tools.util.ObjectStream; -import opennlp.tools.util.ObjectStreamUtils; import opennlp.tools.util.TrainingParameters; /** * Test for naive bayes classification correctness without smoothing */ -public class NaiveBayesCorrectnessTest { +public class NaiveBayesCorrectnessTest extends AbstractNaiveBayesTest { private DataIndexer testDataIndexer; @BeforeEach - void initIndexer() { + void initIndexer() throws IOException { TrainingParameters trainingParameters = new TrainingParameters(); trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1); trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false); testDataIndexer = new TwoPassDataIndexer(); testDataIndexer.init(trainingParameters, new HashMap<>()); - } - - @Test - void testNaiveBayes1() throws IOException { - testDataIndexer.index(createTrainingStream()); - NaiveBayesModel model = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - String label = "politics"; - String[] context = {"bow=united", "bow=nations"}; - Event event = new Event(label, context); - - // testModel(model, event, 1.0); // Expected value without smoothing - testModel(model, event, 0.9681650180264167); // Expected value with smoothing - } - @Test - void testNaiveBayes2() throws IOException { - - testDataIndexer.index(createTrainingStream()); - NaiveBayesModel model = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - String label = "sports"; - String[] context = {"bow=manchester", "bow=united"}; + @ParameterizedTest + @MethodSource("provideLabelsWithContextAndProb") + void testNaiveBayes(String label, String[] context, double expectedProb) { + NaiveBayesModel model = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); Event event = new Event(label, context); - - // testModel(model, event, 1.0); // Expected value without smoothing - testModel(model, event, 0.9658833555831029); // Expected value with smoothing - + + testModel(model, event, expectedProb); // Expected value with smoothing } - @Test - void testNaiveBayes3() throws IOException { - - testDataIndexer.index(createTrainingStream()); - NaiveBayesModel model = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - String label = "politics"; - String[] context = {"bow=united"}; - Event event = new Event(label, context); - - //testModel(model, event, 2.0/3.0); // Expected value without smoothing - testModel(model, event, 0.6655036407766989); // Expected value with smoothing - - } - - @Test - void testNaiveBayes4() throws IOException { - - testDataIndexer.index(createTrainingStream()); - NaiveBayesModel model = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - String label = "politics"; - String[] context = {}; - Event event = new Event(label, context); - - testModel(model, event, 7.0 / 12.0); - + /* + * Produces a stream of <label|context> pairs for parameterized unit tests. + */ + private static Stream<Arguments> provideLabelsWithContextAndProb() { + return Stream.of( + // Example 1: + Arguments.of("politics" , new String[] {"bow=united", "bow=nations"}, 0.9681650180264167), + Arguments.of("sports", new String[] {"bow=manchester", "bow=united"}, 0.9658833555831029), + Arguments.of("politics", new String[] {"bow=united"}, 0.6655036407766989), + Arguments.of("politics", new String[] {}, 7.0 / 12.0) + ); } private void testModel(MaxentModel model, Event event, double higher_probability) { @@ -134,26 +93,4 @@ public class NaiveBayesCorrectnessTest { } } - public static ObjectStream<Event> createTrainingStream() { - List<Event> trainingEvents = new ArrayList<>(); - - String label1 = "politics"; - String[] context1 = {"bow=the", "bow=united", "bow=nations"}; - trainingEvents.add(new Event(label1, context1)); - - String label2 = "politics"; - String[] context2 = {"bow=the", "bow=united", "bow=states", "bow=and"}; - trainingEvents.add(new Event(label2, context2)); - - String label3 = "sports"; - String[] context3 = {"bow=manchester", "bow=united"}; - trainingEvents.add(new Event(label3, context3)); - - String label4 = "sports"; - String[] context4 = {"bow=manchester", "bow=and", "bow=barca"}; - trainingEvents.add(new Event(label4, context4)); - - return ObjectStreamUtils.createObjectStream(trainingEvents); - } - } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java index 64de26ab..6c135c9d 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesModelReadWriteTest.java @@ -18,6 +18,7 @@ package opennlp.tools.ml.naivebayes; import java.io.File; +import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; @@ -34,24 +35,24 @@ import opennlp.tools.ml.model.TwoPassDataIndexer; import opennlp.tools.util.TrainingParameters; /** - * Tests for persisting and reading naive bayes models + * Tests for persisting and reading naive bayes models. */ -public class NaiveBayesModelReadWriteTest { +public class NaiveBayesModelReadWriteTest extends AbstractNaiveBayesTest { private DataIndexer testDataIndexer; @BeforeEach - void initIndexer() { + void initIndexer() throws IOException { TrainingParameters trainingParameters = new TrainingParameters(); trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1); trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false); testDataIndexer = new TwoPassDataIndexer(); testDataIndexer.init(trainingParameters, new HashMap<>()); + testDataIndexer.index(createTrainingStream()); } @Test - void testBinaryModelPersistence() throws Exception { - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); + void testBinaryModelPersistence() throws IOException { NaiveBayesModel model = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); Path tempFile = Files.createTempFile("bnb-", ".bin"); File file = tempFile.toFile(); @@ -69,7 +70,6 @@ public class NaiveBayesModelReadWriteTest { @Test void testTextModelPersistence() throws Exception { - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); NaiveBayesModel model = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); Path tempFile = Files.createTempFile("ptnb-", ".txt"); File file = tempFile.toFile(); diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java index 904e7549..06572abb 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesPrepAttachTest.java @@ -30,8 +30,10 @@ import opennlp.tools.ml.PrepAttachDataUtil; import opennlp.tools.ml.TrainerFactory; import opennlp.tools.ml.model.AbstractDataIndexer; import opennlp.tools.ml.model.DataIndexer; +import opennlp.tools.ml.model.Event; import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.ml.model.TwoPassDataIndexer; +import opennlp.tools.util.ObjectStream; import opennlp.tools.util.TrainingParameters; /** @@ -39,20 +41,23 @@ import opennlp.tools.util.TrainingParameters; */ public class NaiveBayesPrepAttachTest { - private DataIndexer testDataIndexer; + private ObjectStream<Event> trainingStream; @BeforeEach - void initIndexer() { - TrainingParameters trainingParameters = new TrainingParameters(); - trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1); - trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false); - testDataIndexer = new TwoPassDataIndexer(); - testDataIndexer.init(trainingParameters, new HashMap<>()); + void initIndexer() throws IOException { + trainingStream = PrepAttachDataUtil.createTrainingStream(); + Assertions.assertNotNull(trainingStream); } @Test void testNaiveBayesOnPrepAttachData() throws IOException { - testDataIndexer.index(PrepAttachDataUtil.createTrainingStream()); + TrainingParameters trainingParameters = new TrainingParameters(); + trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1); + trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false); + DataIndexer testDataIndexer = new TwoPassDataIndexer(); + testDataIndexer.init(trainingParameters, new HashMap<>()); + testDataIndexer.index(trainingStream); + MaxentModel model = new NaiveBayesTrainer().trainModel(testDataIndexer); Assertions.assertTrue(model instanceof NaiveBayesModel); PrepAttachDataUtil.testModel(model, 0.7897994553107205); @@ -65,7 +70,7 @@ public class NaiveBayesPrepAttachTest { trainParams.put(AbstractTrainer.CUTOFF_PARAM, 1); EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null); - MaxentModel model = trainer.train(PrepAttachDataUtil.createTrainingStream()); + MaxentModel model = trainer.train(trainingStream); Assertions.assertTrue(model instanceof NaiveBayesModel); PrepAttachDataUtil.testModel(model, 0.7897994553107205); } @@ -77,7 +82,7 @@ public class NaiveBayesPrepAttachTest { trainParams.put(AbstractTrainer.CUTOFF_PARAM, 5); EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null); - MaxentModel model = trainer.train(PrepAttachDataUtil.createTrainingStream()); + MaxentModel model = trainer.train(trainingStream); Assertions.assertTrue(model instanceof NaiveBayesModel); PrepAttachDataUtil.testModel(model, 0.7945035899975241); } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java index 31f37426..e0abd823 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/naivebayes/NaiveBayesSerializedCorrectnessTest.java @@ -26,10 +26,14 @@ import java.io.StringWriter; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; +import java.util.stream.Stream; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import opennlp.tools.ml.AbstractTrainer; import opennlp.tools.ml.model.AbstractDataIndexer; @@ -39,101 +43,51 @@ import opennlp.tools.ml.model.TwoPassDataIndexer; import opennlp.tools.util.TrainingParameters; /** - * Test for naive bayes classification correctness without smoothing + * Test for naive bayes classification correctness without smoothing. */ -public class NaiveBayesSerializedCorrectnessTest { +public class NaiveBayesSerializedCorrectnessTest extends AbstractNaiveBayesTest { private DataIndexer testDataIndexer; @BeforeEach - void initIndexer() { + void initIndexer() throws IOException { TrainingParameters trainingParameters = new TrainingParameters(); trainingParameters.put(AbstractTrainer.CUTOFF_PARAM, 1); trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false); testDataIndexer = new TwoPassDataIndexer(); testDataIndexer.init(trainingParameters, new HashMap<>()); + testDataIndexer.index(createTrainingStream()); } - @Test - void testNaiveBayes1() throws IOException { - - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); - NaiveBayesModel model1 = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - NaiveBayesModel model2 = persistedModel(model1); - - String label = "politics"; - String[] context = {"bow=united", "bow=nations"}; - Event event = new Event(label, context); - - testModelOutcome(model1, model2, event); - - } - - @Test - void testNaiveBayes2() throws IOException { - - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); - NaiveBayesModel model1 = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - NaiveBayesModel model2 = persistedModel(model1); - - String label = "sports"; - String[] context = {"bow=manchester", "bow=united"}; - Event event = new Event(label, context); - - testModelOutcome(model1, model2, event); - - } - - @Test - void testNaiveBayes3() throws IOException { - - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); - NaiveBayesModel model1 = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - + @ParameterizedTest + @MethodSource("provideLabelsWithContext") + void testNaiveBayes(String label, String[] context) throws IOException { + NaiveBayesModel model1 = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); NaiveBayesModel model2 = persistedModel(model1); - - String label = "politics"; - String[] context = {"bow=united"}; Event event = new Event(label, context); - testModelOutcome(model1, model2, event); - } - @Test - void testNaiveBayes4() throws IOException { - - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); - NaiveBayesModel model1 = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - - NaiveBayesModel model2 = persistedModel(model1); - - String label = "politics"; - String[] context = {}; - Event event = new Event(label, context); - - testModelOutcome(model1, model2, event); - + /* + * Produces a stream of <label|context> pairs for parameterized unit tests. + */ + private static Stream<Arguments> provideLabelsWithContext() { + return Stream.of( + // Example 1: + Arguments.of("politics" , new String[] {"bow=united", "bow=nations"}), + Arguments.of("sports", new String[] {"bow=manchester", "bow=united"}), + Arguments.of("politics", new String[] {"bow=united"}), + Arguments.of("politics", new String[] {}) + ); } - @Test void testPlainTextModel() throws IOException { - testDataIndexer.index(NaiveBayesCorrectnessTest.createTrainingStream()); - NaiveBayesModel model1 = - (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); - + NaiveBayesModel model1 = (NaiveBayesModel) new NaiveBayesTrainer().trainModel(testDataIndexer); StringWriter sw1 = new StringWriter(); - NaiveBayesModelWriter modelWriter = - new PlainTextNaiveBayesModelWriter(model1, new BufferedWriter(sw1)); + NaiveBayesModelWriter modelWriter = new PlainTextNaiveBayesModelWriter(model1, new BufferedWriter(sw1)); modelWriter.persist(); NaiveBayesModelReader reader = @@ -147,10 +101,9 @@ public class NaiveBayesSerializedCorrectnessTest { modelWriter.persist(); Assertions.assertEquals(sw1.toString(), sw2.toString()); - } - protected static NaiveBayesModel persistedModel(NaiveBayesModel model) throws IOException { + private static NaiveBayesModel persistedModel(NaiveBayesModel model) throws IOException { Path tempFilePath = Files.createTempFile("ptnb-", ".bin"); File file = tempFilePath.toFile(); try { @@ -164,17 +117,15 @@ public class NaiveBayesSerializedCorrectnessTest { } } - protected static void testModelOutcome(NaiveBayesModel model1, NaiveBayesModel model2, Event event) { + private static void testModelOutcome(NaiveBayesModel model1, NaiveBayesModel model2, Event event) { String[] labels1 = extractLabels(model1); String[] labels2 = extractLabels(model2); Assertions.assertArrayEquals(labels1, labels2); - double[] outcomes1 = model1.eval(event.getContext()); double[] outcomes2 = model2.eval(event.getContext()); Assertions.assertArrayEquals(outcomes1, outcomes2, 0.000000000001); - } private static String[] extractLabels(NaiveBayesModel model) { diff --git a/pom.xml b/pom.xml index 3c95fa74..828d06a5 100644 --- a/pom.xml +++ b/pom.xml @@ -326,7 +326,7 @@ <configuration> <failOnUnsupportedJava>false</failOnUnsupportedJava> <bundledSignatures> - <bundledSignature>jdk-deprecated</bundledSignature> + <bundledSignature>jdk-deprecated-${java.version}</bundledSignature> <bundledSignature>jdk-non-portable</bundledSignature> </bundledSignatures> </configuration>
