This is an automated email from the ASF dual-hosted git repository.

rzo1 pushed a commit to branch OPENNLP-1539
in repository https://gitbox.apache.org/repos/asf/opennlp.git

commit 8ef1a5a2df990276af1fea5f96b356be755e74e1
Author: Richard Zowalla <[email protected]>
AuthorDate: Thu May 23 16:59:49 2024 +0200

    OPENNLP-1539 - Introduce parameter for POSTaggerME to configure output POS tag format
---
 .../java/opennlp/tools/postag/POSTagFormat.java}   |  38 +---
 .../opennlp/tools/postag/POSTagFormatMapper.java   | 207 +++++++++++++++++++++
 .../java/opennlp/tools/postag/POSTaggerME.java     |  92 +++++----
 .../tools/namefind/TokenNameFinderModelTest.java   |   2 +-
 .../java/opennlp/tools/postag/POSModelTest.java    |   4 +-
 .../java/opennlp/tools/postag/POSTaggerMETest.java |  57 ++++--
 .../POSTaggerNameFeatureGeneratorTest.java         |   2 +-
 7 files changed, 321 insertions(+), 81 deletions(-)

diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java 
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java
similarity index 54%
rename from opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java
rename to opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java
index 9b521ce8..ddb9cc5f 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java
@@ -14,38 +14,14 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package opennlp.tools.postag;
 
-import java.io.IOException;
-
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-
-public class POSTaggerMEIT {
-
-  private static POSTagger tagger;
-
-  @BeforeAll
-  public static void prepare() throws IOException {
-    tagger = new POSTaggerME("en");
-  }
-
-  @Test
-  void testPOSTagger() {
-
-    String[] tags = tagger.tag(new String[] {
-        "The",
-        "driver",
-        "got",
-        "badly",
-        "injured",
-        "."});
-
-    // TODO OPENNLP-1539 Adjust this depending on the POSFormat
-    String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
-    Assertions.assertArrayEquals(expected, tags);
-  }
+/**
+ * Defines the format for part-of-speech tagging, i.e.
+ * <a href="https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html">PENN</a>
+ * or <a href="https://universaldependencies.org/u/pos/index.html">UD</a> format.
+ */
+public enum POSTagFormat {
 
+  /** Universal Dependencies universal POS tags (UPOS), e.g. {@code NOUN} or {@code VERB}. */
+  UD,
+  /** Penn Treebank tags, e.g. {@code NN} or {@code VBD}. */
+  PENN,
+  /** A tag set that could not be identified as either {@link #UD} or {@link #PENN}. */
+  UNKNOWN
 }
diff --git 
a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java 
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
new file mode 100644
index 00000000..3001dfd1
--- /dev/null
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.tools.postag;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A mapping implementation for converting between different POS tag formats.
+ * This class supports conversion between Penn Treebank (PENN) and Universal Dependencies (UD)
+ * formats. The conversion is based on the
+ * <a href="https://universaldependencies.org/tagset-conversion/en-penn-uposf.html">Universal
+ * Dependencies conversion table.</a>
+ * <p>
+ * The direction of the conversion is fixed by the format guessed from the model's outcomes:
+ * tags of a UD model are converted to PENN, tags of a PENN model are converted to UD, and tags
+ * of an {@link POSTagFormat#UNKNOWN} model are passed through unchanged. Tags that are missing
+ * from the relevant conversion table are converted to {@code "?"}.
+ * <p>
+ * Please note that converting from UD to PENN format is lossy: several PENN tags collapse into
+ * a single UD tag, so only one representative PENN tag can be returned.
+ */
+public class POSTagFormatMapper {
+
+  private static final Logger logger = LoggerFactory.getLogger(POSTagFormatMapper.class);
+
+  private static final Map<String, String> CONVERSION_TABLE_PENN_TO_UD = new HashMap<>();
+  private static final Map<String, String> CONVERSION_TABLE_UD_TO_PENN = new HashMap<>();
+
+  static {
+    /*
+     * This is a conversion table to convert PENN to UD format as described in
+     * https://universaldependencies.org/tagset-conversion/en-penn-uposf.html
+     */
+    CONVERSION_TABLE_PENN_TO_UD.put("#", "SYM");
+    CONVERSION_TABLE_PENN_TO_UD.put("$", "SYM");
+    CONVERSION_TABLE_PENN_TO_UD.put("''", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put(",", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("-LRB-", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("-RRB-", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put(".", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put(":", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("AFX", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("CC", "CCONJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("CD", "NUM");
+    CONVERSION_TABLE_PENN_TO_UD.put("DT", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("EX", "PRON");
+    CONVERSION_TABLE_PENN_TO_UD.put("FW", "X");
+    CONVERSION_TABLE_PENN_TO_UD.put("HYPH", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("IN", "ADP");
+    CONVERSION_TABLE_PENN_TO_UD.put("JJ", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("JJR", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("JJS", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("LS", "X");
+    CONVERSION_TABLE_PENN_TO_UD.put("MD", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("NIL", "X");
+    CONVERSION_TABLE_PENN_TO_UD.put("NN", "NOUN");
+    CONVERSION_TABLE_PENN_TO_UD.put("NNP", "PROPN");
+    CONVERSION_TABLE_PENN_TO_UD.put("NNPS", "PROPN");
+    CONVERSION_TABLE_PENN_TO_UD.put("NNS", "NOUN");
+    CONVERSION_TABLE_PENN_TO_UD.put("PDT", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("POS", "PART");
+    CONVERSION_TABLE_PENN_TO_UD.put("PRP", "PRON");
+    CONVERSION_TABLE_PENN_TO_UD.put("PRP$", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("RB", "ADV");
+    CONVERSION_TABLE_PENN_TO_UD.put("RBR", "ADV");
+    CONVERSION_TABLE_PENN_TO_UD.put("RBS", "ADV");
+    CONVERSION_TABLE_PENN_TO_UD.put("RP", "ADP");
+    CONVERSION_TABLE_PENN_TO_UD.put("SYM", "SYM");
+    CONVERSION_TABLE_PENN_TO_UD.put("TO", "PART");
+    CONVERSION_TABLE_PENN_TO_UD.put("UH", "INTJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("VB", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBD", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBG", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBN", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBP", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBZ", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("WDT", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("WP", "PRON");
+    CONVERSION_TABLE_PENN_TO_UD.put("WP$", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("WRB", "ADV");
+    // Opening quotes; listed last in the UD table (ASCII order), counterpart of "''" above.
+    CONVERSION_TABLE_PENN_TO_UD.put("``", "PUNCT");
+
+    /*
+     * Note: The back conversion might lose information, as several PENN tags share one UD tag.
+     */
+    CONVERSION_TABLE_UD_TO_PENN.put("ADJ", "JJ");
+    CONVERSION_TABLE_UD_TO_PENN.put("ADP", "IN");
+    CONVERSION_TABLE_UD_TO_PENN.put("ADV", "RB");
+    CONVERSION_TABLE_UD_TO_PENN.put("AUX", "MD");
+    CONVERSION_TABLE_UD_TO_PENN.put("CCONJ", "CC");
+    CONVERSION_TABLE_UD_TO_PENN.put("DET", "DT");
+    CONVERSION_TABLE_UD_TO_PENN.put("INTJ", "UH");
+    CONVERSION_TABLE_UD_TO_PENN.put("NOUN", "NN");
+    CONVERSION_TABLE_UD_TO_PENN.put("NUM", "CD");
+    // 'TO' and 'POS' are the PENN tags mapped onto PART above ('RP' maps to ADP),
+    // so PART is converted back to 'TO'.
+    CONVERSION_TABLE_UD_TO_PENN.put("PART", "TO");
+    CONVERSION_TABLE_UD_TO_PENN.put("PRON", "PRP");
+    CONVERSION_TABLE_UD_TO_PENN.put("PROPN", "NNP");
+    CONVERSION_TABLE_UD_TO_PENN.put("PUNCT", ".");
+    CONVERSION_TABLE_UD_TO_PENN.put("SCONJ", "IN");
+    CONVERSION_TABLE_UD_TO_PENN.put("SYM", "SYM");
+    CONVERSION_TABLE_UD_TO_PENN.put("VERB", "VB");
+    CONVERSION_TABLE_UD_TO_PENN.put("X", "FW");
+  }
+
+  private final POSTagFormat modelFormat;
+
+  /**
+   * Creates a mapper for a model whose possible outcomes are {@code possibleOutcomes}.
+   * The model's {@link POSTagFormat} is guessed once, here.
+   *
+   * @param possibleOutcomes All tags the model can produce. Must not be {@code null}.
+   * @throws NullPointerException Thrown if {@code possibleOutcomes} is {@code null}.
+   */
+  protected POSTagFormatMapper(final String[] possibleOutcomes) {
+    Objects.requireNonNull(possibleOutcomes, "Outcomes must not be NULL.");
+    this.modelFormat = guessModelTagFormat(possibleOutcomes);
+  }
+
+  /**
+   * Converts each of the given tags, see {@link #convertTag(String)}.
+   *
+   * @param tags The tags to be converted. Must not be {@code null}.
+   * @return The converted tags, in the same order as {@code tags}.
+   * @throws NullPointerException Thrown if {@code tags} is {@code null}.
+   */
+  public String[] convertTags(List<String> tags) {
+    Objects.requireNonNull(tags, "Supplied tags must not be NULL.");
+    return tags.stream()
+        .map(this::convertTag)
+        .toArray(String[]::new);
+  }
+
+  /**
+   * Converts a single tag away from the guessed model format:
+   * <ul>
+   *   <li>{@link POSTagFormat#UD} model: the UD tag is converted to a PENN tag (lossy).</li>
+   *   <li>{@link POSTagFormat#PENN} model: the PENN tag is converted to a UD tag.</li>
+   *   <li>otherwise: the tag is returned unchanged.</li>
+   * </ul>
+   *
+   * @param tag The tag to convert; no restrictions on this parameter.
+   * @return The converted tag, or {@code "?"} if the tag is not in the relevant conversion table.
+   */
+  public String convertTag(String tag) {
+    return switch (modelFormat) {
+      case UD -> {
+        // Only the UD -> PENN direction is ambiguous, so that is where it gets reported.
+        logAmbiguity(tag);
+        yield CONVERSION_TABLE_UD_TO_PENN.getOrDefault(tag, "?");
+      }
+      case PENN -> CONVERSION_TABLE_PENN_TO_UD.getOrDefault(tag, "?");
+      default -> tag;
+    };
+  }
+
+  /**
+   * Logs when a UD tag has several PENN counterparts and only one of them is returned.
+   * Uses debug level because this runs once per converted token and would otherwise flood logs.
+   *
+   * @param udTag The UD tag being converted; may be {@code null}.
+   */
+  private static void logAmbiguity(String udTag) {
+    if (udTag == null || !logger.isDebugEnabled()) {
+      return; // switch on a null String would throw
+    }
+    final String candidates = switch (udTag) {
+      case "NOUN" -> "'NN' or 'NNS'";
+      case "PART" -> "'TO' or 'POS'";
+      case "PROPN" -> "'NNP' or 'NNPS'";
+      case "PUNCT" -> "any PENN punctuation tag";
+      case "VERB" -> "'MD', 'VB', 'VBD', 'VBG', 'VBN', 'VBP' or 'VBZ'";
+      default -> null;
+    };
+    if (candidates != null) {
+      // The returned tag is read from the table so the message cannot drift from the mapping.
+      logger.debug("Ambiguity detected: '{}' can be {}. Returning '{}'.",
+          udTag, candidates, CONVERSION_TABLE_UD_TO_PENN.get(udTag));
+    }
+  }
+
+  /**
+   * Returns the format guessed from the model's outcomes when this mapper was created.
+   *
+   * @return The guessed {@link POSTagFormat}. Guaranteed to be not {@code null}.
+   */
+  public POSTagFormat getGuessedFormat() {
+    return this.modelFormat;
+  }
+
+  /**
+   * Guesses the {@link POSTagFormat} by using majority quorum: every outcome that is a known UD
+   * tag counts for UD, every outcome that is a known PENN tag counts for PENN (a tag known to
+   * both tables, such as {@code SYM}, counts for both). A tie, including an empty array,
+   * yields {@link POSTagFormat#UNKNOWN}.
+   *
+   * @param outcomes must not be {@code null}.
+   * @return the guessed {@link POSTagFormat}, never {@code null}.
+   */
+  private POSTagFormat guessModelTagFormat(final String[] outcomes) {
+    int udMatches = 0;
+    int pennMatches = 0;
+
+    for (String outcome : outcomes) {
+      if (CONVERSION_TABLE_UD_TO_PENN.containsKey(outcome)) {
+        udMatches++;
+      }
+      if (CONVERSION_TABLE_PENN_TO_UD.containsKey(outcome)) {
+        pennMatches++;
+      }
+    }
+
+    if (udMatches > pennMatches) {
+      return POSTagFormat.UD;
+    } else if (pennMatches > udMatches) {
+      return POSTagFormat.PENN;
+    } else {
+      logger.warn("Detected an unknown POS format.");
+      return POSTagFormat.UNKNOWN;
+    }
+  }
+}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java 
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
index 0268f48b..56b77c32 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
@@ -19,6 +19,7 @@ package opennlp.tools.postag;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -83,16 +84,30 @@ public class POSTaggerME implements POSTagger {
 
   private final SequenceValidator<String> sequenceValidator;
 
+  private final POSTagFormat posTagFormat;
+  private final POSTagFormatMapper posTagFormatMapper;
+
   /**
    * Initializes a {@link POSTaggerME} by downloading a default model for a 
given
    * {@code language}.
    *
    * @param language An ISO conform language code.
-   *                 
    * @throws IOException Thrown if the model could not be downloaded or saved.
    */
   public POSTaggerME(String language) throws IOException {
-    this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, 
POSModel.class));
+    this(language, POSTagFormat.UD);
+  }
+
+  /**
+   * Initializes a {@link POSTaggerME} by downloading a default model for a given
+   * {@code language}.
+   * <p>
+   * The tag format of the downloaded model is guessed from its tags. If it differs from
+   * {@code format}, tags are converted (where a mapping exists) before they are returned.
+   *
+   * @param language An ISO conform language code.
+   * @param format   The {@link POSTagFormat} in which this tagger should return its tags.
+   * @throws IOException Thrown if the model could not be downloaded or saved.
+   */
+  public POSTaggerME(String language, POSTagFormat format) throws IOException {
+    this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, 
POSModel.class), format);
   }
 
   /**
@@ -101,6 +116,17 @@ public class POSTaggerME implements POSTagger {
    * @param model A valid {@link POSModel}.
    */
   public POSTaggerME(POSModel model) {
+    this(model, POSTagFormat.UD);
+  }
+
+  /**
+   * Initializes a {@link POSTaggerME} with the provided {@link POSModel 
model}.
+   *
+   * @param model  A valid {@link POSModel}.
+   * @param format A valid {@link POSTagFormat}.
+   */
+  public POSTaggerME(POSModel model, POSTagFormat format) {
+    this.posTagFormat = format;
     POSTaggerFactory factory = model.getFactory();
 
     int beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;
@@ -121,12 +147,13 @@ public class POSTaggerME implements POSTagger {
 
     if (model.getPosSequenceModel() != null) {
       this.model = model.getPosSequenceModel();
-    }
-    else {
+    } else {
       this.model = new opennlp.tools.ml.BeamSearch<>(beamSize,
           model.getPosModel(), 0);
     }
 
+    this.posTagFormatMapper = new POSTagFormatMapper(getAllPosTags());
+
   }
 
   /**
@@ -144,16 +171,15 @@ public class POSTaggerME implements POSTagger {
   @Override
   public String[] tag(String[] sentence, Object[] additionalContext) {
     bestSequence = model.bestSequence(sentence, additionalContext, contextGen, 
sequenceValidator);
-    List<String> t = bestSequence.getOutcomes();
-    return t.toArray(new String[0]);
+    final List<String> t = bestSequence.getOutcomes();
+    return convertTags(t);
   }
 
   /**
    * Returns at most the specified {@code numTaggings} for the specified 
{@code sentence}.
    *
    * @param numTaggings The number of tagging to be returned.
-   * @param sentence An array of tokens which make up a sentence.
-   *
+   * @param sentence    An array of tokens which make up a sentence.
    * @return At most the specified number of taggings for the specified {@code 
sentence}.
    */
   public String[][] tag(int numTaggings, String[] sentence) {
@@ -162,11 +188,19 @@ public class POSTaggerME implements POSTagger {
     String[][] tags = new String[bestSequences.length][];
     for (int si = 0; si < tags.length; si++) {
       List<String> t = bestSequences[si].getOutcomes();
-      tags[si] = t.toArray(new String[0]);
+      tags[si] = convertTags(t);
     }
     return tags;
   }
 
+  /**
+   * Converts tags produced by the underlying model into the requested output format.
+   * Tags are returned unchanged when the requested format is {@link POSTagFormat#UNKNOWN},
+   * when the model's format could not be determined, or when both formats already match.
+   *
+   * @param t The tags as produced by the model.
+   * @return The tags in the requested {@link POSTagFormat}.
+   */
+  private String[] convertTags(List<String> t) {
+    final POSTagFormat modelFormat = posTagFormatMapper.getGuessedFormat();
+    if (posTagFormat == POSTagFormat.UNKNOWN
+        || modelFormat == POSTagFormat.UNKNOWN
+        || modelFormat == posTagFormat) {
+      // There is no known target (or source) format, or nothing to do: pass the tags through.
+      // Previously, requesting UNKNOWN on a UD/PENN model converted to the *other* format.
+      return t.toArray(new String[0]);
+    }
+    return posTagFormatMapper.convertTags(t);
+  }
+
   @Override
   public Sequence[] topKSequences(String[] sentence) {
     return this.topKSequences(sentence, null);
@@ -194,10 +228,10 @@ public class POSTaggerME implements POSTagger {
   }
 
   public String[] getOrderedTags(List<String> words, List<String> tags, int 
index) {
-    return getOrderedTags(words,tags,index,null);
+    return getOrderedTags(words, tags, index, null);
   }
 
-  public String[] getOrderedTags(List<String> words, List<String> tags, int 
index,double[] tprobs) {
+  public String[] getOrderedTags(List<String> words, List<String> tags, int 
index, double[] tprobs) {
 
     if (modelPackage.getPosModel() != null) {
 
@@ -205,7 +239,7 @@ public class POSTaggerME implements POSTagger {
 
       double[] probs = posModel.eval(contextGen.getContext(index,
           words.toArray(new String[0]),
-          tags.toArray(new String[0]),null));
+          tags.toArray(new String[0]), null));
 
       String[] orderedTags = new String[probs.length];
       for (int i = 0; i < probs.length; i++) {
@@ -221,17 +255,16 @@ public class POSTaggerME implements POSTagger {
         }
         probs[max] = 0;
       }
-      return orderedTags;
-    }
-    else {
+      return convertTags(Arrays.stream(orderedTags).toList());
+    } else {
       throw new UnsupportedOperationException("This method can only be called 
if the "
           + "classification model is an event model!");
     }
   }
 
   public static POSModel train(String languageCode,
-      ObjectStream<POSSample> samples, TrainingParameters trainParams,
-      POSTaggerFactory posFactory) throws IOException {
+                               ObjectStream<POSSample> samples, 
TrainingParameters trainParams,
+                               POSTaggerFactory posFactory) throws IOException 
{
 
     int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, 
POSTaggerME.DEFAULT_BEAM_SIZE);
 
@@ -249,14 +282,12 @@ public class POSTaggerME implements POSTagger {
       EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams,
           manifestInfoEntries);
       posModel = trainer.train(es);
-    }
-    else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
+    } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
       POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, 
contextGenerator);
       EventModelSequenceTrainer<POSSample> trainer =
           TrainerFactory.getEventModelSequenceTrainer(trainParams, 
manifestInfoEntries);
       posModel = trainer.train(ss);
-    }
-    else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
+    } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
       SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
           trainParams, manifestInfoEntries);
 
@@ -264,15 +295,13 @@ public class POSTaggerME implements POSTagger {
 
       POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, 
contextGenerator);
       seqPosModel = trainer.train(ss);
-    }
-    else {
+    } else {
       throw new IllegalArgumentException("Trainer type is not supported: " + 
trainerType);
     }
 
     if (posModel != null) {
       return new POSModel(languageCode, posModel, beamSize, 
manifestInfoEntries, posFactory);
-    }
-    else {
+    } else {
       return new POSModel(languageCode, seqPosModel, manifestInfoEntries, 
posFactory);
     }
   }
@@ -282,9 +311,7 @@ public class POSTaggerME implements POSTagger {
    *
    * @param samples The {@link ObjectStream} to process.
    * @param cutoff  A non-negative cut-off value.
-   *
    * @return A valid {@link Dictionary} instance holding nGrams.
-   *
    * @throws IOException Thrown if IO errors occurred during dictionary 
construction.
    */
   public static Dictionary buildNGramDictionary(ObjectStream<POSSample> 
samples, int cutoff)
@@ -295,8 +322,9 @@ public class POSTaggerME implements POSTagger {
     while ((sample = samples.read()) != null) {
       String[] words = sample.getSentence();
 
-      if (words.length > 0)
+      if (words.length > 0) {
         ngramModel.add(new StringList(words), 1, 1);
+      }
     }
 
     ngramModel.cutoff(cutoff, Integer.MAX_VALUE);
@@ -308,13 +336,12 @@ public class POSTaggerME implements POSTagger {
    * Populates a {@link POSDictionary} from an {@link ObjectStream} of samples.
    *
    * @param samples The {@link ObjectStream} to process.
-   * @param dict The {@link MutableTagDictionary} to use during population.
+   * @param dict    The {@link MutableTagDictionary} to use during population.
    * @param cutoff  A non-negative cut-off value.
-   *
    * @throws IOException Thrown if IO errors occurred during dictionary 
construction.
    */
   public static void populatePOSDictionary(ObjectStream<POSSample> samples,
-      MutableTagDictionary dict, int cutoff) throws IOException {
+                                           MutableTagDictionary dict, int 
cutoff) throws IOException {
 
     logger.info("Expanding POS Dictionary ...");
     long start = System.nanoTime();
@@ -377,6 +404,7 @@ public class POSTaggerME implements POSTagger {
       }
     }
 
-    logger.info("... finished expanding POS Dictionary. [ {} ms]", 
(System.nanoTime() - start) / 1000000 );
+    logger.info("... finished expanding POS Dictionary. [ {} ms]", 
(System.nanoTime() - start) / 1000000);
   }
+
 }
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
 
b/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
index 5379a1f3..c77b2874 100644
--- 
a/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
+++ 
b/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
@@ -57,7 +57,7 @@ public class TokenNameFinderModelTest extends 
AbstractModelLoaderTest {
     Path resourcesFolder = 
Files.createTempDirectory("resources").toAbsolutePath();
 
     // save a POS model there
-    POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT);
+    POSModel posModel = 
POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT);
     Assertions.assertNotNull(posModel);
 
     File posModelFile = new File(resourcesFolder.toFile(), "pos-model.bin");
diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java 
b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
index 14086c16..1565c45b 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
@@ -30,7 +30,7 @@ public class POSModelTest {
 
   @Test
   void testPOSModelSerializationMaxent() throws IOException {
-    POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT);
+    POSModel posModel = 
POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT);
     Assertions.assertFalse(posModel.isLoadedFromSerialized());
 
     try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
@@ -45,7 +45,7 @@ public class POSModelTest {
 
   @Test
   void testPOSModelSerializationPerceptron() throws IOException {
-    POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.PERCEPTRON);
+    POSModel posModel = 
POSTaggerMETest.trainPennFormatPOSModel(ModelType.PERCEPTRON);
     Assertions.assertFalse(posModel.isLoadedFromSerialized());
     
     try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java 
b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java
index 945de120..9f9b6dd1 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java
@@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import opennlp.tools.EnabledWhenCDNAvailable;
 import opennlp.tools.formats.ResourceAsStreamFactory;
 import opennlp.tools.util.InputStreamFactory;
 import opennlp.tools.util.InsufficientTrainingDataException;
@@ -39,7 +40,7 @@ public class POSTaggerMETest {
 
   private static ObjectStream<POSSample> createSampleStream() throws 
IOException {
     InputStreamFactory in = new ResourceAsStreamFactory(POSTaggerMETest.class,
-        "/opennlp/tools/postag/AnnotatedSentences.txt");
+        "/opennlp/tools/postag/AnnotatedSentences.txt"); //PENN FORMAT
 
     return new WordTagSampleStream(new PlainTextByLineStream(in, 
StandardCharsets.UTF_8));
   }
@@ -49,7 +50,7 @@ public class POSTaggerMETest {
    *
    * @return {@link POSModel}
    */
-  public static POSModel trainPOSModel(ModelType type) throws IOException {
+  public static POSModel trainPennFormatPOSModel(ModelType type) throws 
IOException {
     TrainingParameters params = new TrainingParameters();
     params.put(TrainingParameters.ALGORITHM_PARAM, type.toString());
     params.put(TrainingParameters.ITERATIONS_PARAM, 100);
@@ -61,25 +62,53 @@ public class POSTaggerMETest {
 
   @Test
   void testPOSTagger() throws IOException {
-    POSModel posModel = trainPOSModel(ModelType.MAXENT);
+    final String[] sentence = {
+        "The",
+        "driver",
+        "got",
+        "badly",
+        "injured",
+        "."};
 
-    POSTagger tagger = new POSTaggerME(posModel);
+    final String[] expected = {"DT", "NN", "VBD", "RB", "VBN", "."};
+    testPOSTagger(new POSTaggerME(trainPennFormatPOSModel(ModelType.MAXENT),
+        POSTagFormat.PENN), sentence, expected);
+  }
 
-    String[] tags = tagger.tag(new String[] {
+  @Test
+  void testPOSTaggerPENNtoUD() throws IOException {
+    final String[] sentence = {
         "The",
         "driver",
         "got",
         "badly",
         "injured",
-        "."});
-
-    Assertions.assertEquals(6, tags.length);
-    Assertions.assertEquals("DT", tags[0]);
-    Assertions.assertEquals("NN", tags[1]);
-    Assertions.assertEquals("VBD", tags[2]);
-    Assertions.assertEquals("RB", tags[3]);
-    Assertions.assertEquals("VBN", tags[4]);
-    Assertions.assertEquals(".", tags[5]);
+        "."};
+
+    final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
+    //convert PENN to UD on the fly.
+    testPOSTagger(new POSTaggerME(trainPennFormatPOSModel(ModelType.MAXENT),
+        POSTagFormat.UD), sentence, expected);
+  }
+
+  @Test
+  @EnabledWhenCDNAvailable(hostname = "dlcdn.apache.org")
+  void testPOSTaggerDefault() throws IOException {
+    final String[] sentence = {
+        "The",
+        "driver",
+        "got",
+        "badly",
+        "injured",
+        "."};
+
+    final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
+    // Downloads the default "en" model (a UD model); the default output format is UD as well.
+    testPOSTagger(new POSTaggerME("en"), sentence, expected);
+  }
+
+  /**
+   * Tags a single tokenized sentence and asserts that the result equals the expected tags.
+   *
+   * @param tagger       The {@link POSTagger} under test.
+   * @param sentence     The tokens of one sentence.
+   * @param expectedTags The expected tag for each token, in order.
+   */
+  private void testPOSTagger(POSTagger tagger, String[] sentence, String[] expectedTags) {
+    Assertions.assertArrayEquals(expectedTags, tagger.tag(sentence));
   }
 
   @Test
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
 
b/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
index ccd1f9f2..c349ee71 100644
--- 
a/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
+++ 
b/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
@@ -33,7 +33,7 @@ public class POSTaggerNameFeatureGeneratorTest {
   @Test
   void testFeatureGeneration() throws IOException {
     POSTaggerNameFeatureGenerator fg = new POSTaggerNameFeatureGenerator(
-        POSTaggerMETest.trainPOSModel(ModelType.MAXENT));
+        POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT));
 
     String[] tokens = {"Hi", "Mike", ",", "it", "'s", "Stefanie", "Schmidt", 
"."};
     for (int i = 0; i < tokens.length; i++) {

Reply via email to