This is an automated email from the ASF dual-hosted git repository.

jzemerick pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/master by this push:
     new 02eaffa4 OPENNLP-1407: Adding sentence detector to NameFinderDL. (#447)
02eaffa4 is described below

commit 02eaffa4cfb0f9c2c2ea1da864f8cc53d215e97d
Author: Jeff Zemerick <[email protected]>
AuthorDate: Sat Dec 10 09:18:17 2022 -0500

    OPENNLP-1407: Adding sentence detector to NameFinderDL. (#447)
---
 .../java/opennlp/dl/namefinder/NameFinderDL.java   | 194 +++++++++++----------
 .../opennlp/dl/namefinder/NameFinderDLEval.java    |  23 ++-
 2 files changed, 121 insertions(+), 96 deletions(-)

diff --git a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index ff132897..319a2074 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -39,6 +39,7 @@ import opennlp.dl.InferenceOptions;
 import opennlp.dl.SpanEnd;
 import opennlp.dl.Tokens;
 import opennlp.tools.namefind.TokenNameFinder;
+import opennlp.tools.sentdetect.SentenceDetector;
 import opennlp.tools.tokenize.Tokenizer;
 import opennlp.tools.tokenize.WordpieceTokenizer;
 import opennlp.tools.util.Span;
@@ -54,23 +55,27 @@ public class NameFinderDL implements TokenNameFinder {
 
   public static final String I_PER = "I-PER";
   public static final String B_PER = "B-PER";
+  public static final String SEPARATOR = "[SEP]";
 
   protected final OrtSession session;
 
+  private final SentenceDetector sentenceDetector;
   private final Map<Integer, String> ids2Labels;
   private final Tokenizer tokenizer;
   private final Map<String, Integer> vocab;
   private final InferenceOptions inferenceOptions;
   protected final OrtEnvironment env;
 
-  public NameFinderDL(File model, File vocabulary, Map<Integer, String> 
ids2Labels) throws Exception {
+  public NameFinderDL(File model, File vocabulary, Map<Integer, String> 
ids2Labels,
+                      SentenceDetector sentenceDetector) throws Exception {
 
-    this(model, vocabulary, ids2Labels, new InferenceOptions());
+    this(model, vocabulary, ids2Labels, new InferenceOptions(), 
sentenceDetector);
 
   }
 
   public NameFinderDL(File model, File vocabulary, Map<Integer, String> 
ids2Labels,
-                      InferenceOptions inferenceOptions) throws Exception {
+                      InferenceOptions inferenceOptions,
+                      SentenceDetector sentenceDetector) throws Exception {
 
     this.env = OrtEnvironment.getEnvironment();
 
@@ -84,145 +89,156 @@ public class NameFinderDL implements TokenNameFinder {
     this.vocab = loadVocab(vocabulary);
     this.tokenizer = new WordpieceTokenizer(vocab.keySet());
     this.inferenceOptions = inferenceOptions;
+    this.sentenceDetector = sentenceDetector;
 
   }
 
   @Override
   public Span[] find(String[] input) {
 
-
-
-    /**
-     * So, it looks like inference is being done on the wordpiece tokens but 
then
-     * spans are being created from the whitespace tokens.
-     */
-
     final List<Span> spans = new LinkedList<>();
 
     // Join the tokens here because they will be tokenized using Wordpiece 
during inference.
     final String text = String.join(" ", input);
 
-    // The WordPiece tokenized text. This changes the spacing in the text.
-    final List<Tokens> wordpieceTokens = tokenize(text);
+    final String[] sentences = sentenceDetector.sentDetect(text);
 
-    for (final Tokens tokens : wordpieceTokens) {
+    for (String sentence : sentences) {
 
-      try {
+      // The WordPiece tokenized text. This changes the spacing in the text.
+      final List<Tokens> wordpieceTokens = tokenize(sentence);
 
-        // The inputs to the ONNX model.
-        final Map<String, OnnxTensor> inputs = new HashMap<>();
-        inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, 
LongBuffer.wrap(tokens.getIds()),
-            new long[] {1, tokens.getIds().length}));
+      for (final Tokens tokens : wordpieceTokens) {
 
-        if (inferenceOptions.isIncludeAttentionMask()) {
-          inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
-              LongBuffer.wrap(tokens.getMask()), new long[] {1, 
tokens.getMask().length}));
-        }
+        try {
 
-        if (inferenceOptions.isIncludeTokenTypeIds()) {
-          inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
-              LongBuffer.wrap(tokens.getTypes()), new long[] {1, 
tokens.getTypes().length}));
-        }
+          // The inputs to the ONNX model.
+          final Map<String, OnnxTensor> inputs = new HashMap<>();
+          inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, 
LongBuffer.wrap(tokens.getIds()),
+              new long[] {1, tokens.getIds().length}));
 
-        // The outputs from the model.
-        final float[][][] v = (float[][][]) 
session.run(inputs).get(0).getValue();
+          if (inferenceOptions.isIncludeAttentionMask()) {
+            inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+                LongBuffer.wrap(tokens.getMask()), new long[] {1, 
tokens.getMask().length}));
+          }
 
-        // Find consecutive B-PER and I-PER labels and combine the spans where 
necessary.
-        // There are also B-LOC and I-LOC tags for locations that might be 
useful at some point.
+          if (inferenceOptions.isIncludeTokenTypeIds()) {
+            inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+                LongBuffer.wrap(tokens.getTypes()), new long[] {1, 
tokens.getTypes().length}));
+          }
 
-        // Keep track of where the last span was so when there are 
multiple/duplicate
-        // spans we can get the next one instead of the first one each time.
-        int characterStart = 0;
+          // The outputs from the model.
+          final float[][][] v = (float[][][]) 
session.run(inputs).get(0).getValue();
 
-        // We are looping over the vector for each word,
-        // finding the index of the array that has the maximum value,
-        // and then finding the token classification that corresponds to that 
index.
-        for (int x = 0; x < v[0].length; x++) {
+          // Find consecutive B-PER and I-PER labels and combine the spans 
where necessary.
+          // There are also B-LOC and I-LOC tags for locations that might be 
useful at some point.
 
-          final float[] arr = v[0][x];
-          final int maxIndex = maxIndex(arr);
-          final String label = ids2Labels.get(maxIndex);
+          // Keep track of where the last span was so when there are 
multiple/duplicate
+          // spans we can get the next one instead of the first one each time.
+          int characterStart = 0;
 
-          // TODO: Need to make sure this value is between 0 and 1?
-          // Can we do thresholding without it between 0 and 1?
-          final double confidence = arr[maxIndex]; // / 10;
+          // We are looping over the vector for each word,
+          // finding the index of the array that has the maximum value,
+          // and then finding the token classification that corresponds to 
that index.
+          for (int x = 0; x < v[0].length; x++) {
 
-          // Show each token and its label per the model.
-          // System.out.println(tokens.getTokens()[x] + " : " + label);
+            final float[] arr = v[0][x];
+            final int maxIndex = maxIndex(arr);
+            final String label = ids2Labels.get(maxIndex);
 
-          // Is this is the start of a person entity.
-          if (B_PER.equals(label)) {
+            // TODO: Need to make sure this value is between 0 and 1?
+            // Can we do thresholding without it between 0 and 1?
+            final double confidence = arr[maxIndex]; // / 10;
 
-            final String spanText;
+            // Is this is the start of a person entity.
+            if (B_PER.equals(label)) {
 
-            // Find the end index of the span in the array (where the label is 
not I-PER).
-            final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels, 
tokens.getTokens());
+              String spanText;
 
-            // If the end is -1 it means this is a single-span token.
-            // If the end is != -1 it means this is a multi-span token.
-            if (spanEnd.getIndex() != -1) {
+              // Find the end index of the span in the array (where the label 
is not I-PER).
+              final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels, 
tokens.getTokens());
 
-              final StringBuilder sb = new StringBuilder();
+              // If the end is -1 it means this is a single-span token.
+              // If the end is != -1 it means this is a multi-span token.
+              if (spanEnd.getIndex() != -1) {
 
-              // We have to concatenate the tokens.
-              // Add each token in the array and separate them with a space.
-              // We'll separate each with a single space because later we'll 
find the original span
-              // in the text and ignore spacing between individual tokens in 
findByRegex().
-              int end = spanEnd.getIndex();
-              for (int i = x; i <= end; i++) {
+                final StringBuilder sb = new StringBuilder();
 
-                // If the next token starts with ##, combine it with this 
token.
-                if (tokens.getTokens()[i + 1].startsWith("##")) {
+                // We have to concatenate the tokens.
+                // Add each token in the array and separate them with a space.
+                // We'll separate each with a single space because later we'll 
find the original span
+                // in the text and ignore spacing between individual tokens in 
findByRegex().
+                int end = spanEnd.getIndex();
+                for (int i = x; i <= end; i++) {
 
-                  sb.append(tokens.getTokens()[i] + tokens.getTokens()[i + 
1].replaceAll("##", ""));
+                  // If the next token starts with ##, combine it with this 
token.
+                  if (tokens.getTokens()[i + 1].startsWith("##")) {
 
-                  // Append a space unless the next (next) token starts with 
##.
-                  if (!tokens.getTokens()[i + 2].startsWith("##")) {
-                    sb.append(" ");
-                  }
+                    sb.append(tokens.getTokens()[i] + tokens.getTokens()[i + 
1].replaceAll("##", ""));
+
+                    // Append a space unless the next (next) token starts with 
##.
+                    if (!tokens.getTokens()[i + 2].startsWith("##")) {
+                      sb.append(" ");
+                    }
+
+                    // Skip the next token since we just included it in this 
iteration.
+                    i++;
 
-                  // Skip the next token since we just included it in this 
iteration.
-                  i++;
+                  } else {
 
-                } else {
+                    sb.append(tokens.getTokens()[i].replaceAll("##", ""));
 
-                  sb.append(tokens.getTokens()[i].replaceAll("##", ""));
+                    // Append a space unless the next token is a period.
+                    if (!".".equals(tokens.getTokens()[i + 1])) {
+                      sb.append(" ");
+                    }
 
-                  // Append a space unless the next token is a period.
-                  if (!".".equals(tokens.getTokens()[i + 1])) {
-                    sb.append(" ");
                   }
 
                 }
 
+                // This is the text of the span. We use the whole original 
input text and not one
+                // of the splits. This gives us accurate character positions.
+                spanText = findByRegex(text, sb.toString().trim()).trim();
+
+              } else {
+
+                // This is a single-token span so there is nothing else to do 
except grab the token.
+                spanText = tokens.getTokens()[x];
+
               }
 
-              // This is the text of the span. We use the whole original input 
text and not one
-              // of the splits. This gives us accurate character positions.
-              spanText = findByRegex(text, sb.toString().trim()).trim();
+              if (!SEPARATOR.equals(spanText)) {
 
-            } else {
+                spanText = spanText.replaceAll("##", "");
 
-              // This is a single-token span so there is nothing else to do 
except grab the token.
-              spanText = tokens.getTokens()[x];
+                // This ignores other potential matches in the same sentence
+                // by only taking the first occurrence.
+                characterStart = text.indexOf(spanText, characterStart);
 
-            }
+                // TODO: This check should not be needed because the span was 
found.
+                // If we aren't finding it now it's because there's a 
whitespace difference.
+                if (characterStart != -1) {
+
+                  final int characterEnd = characterStart + spanText.length();
 
-            // This ignores other potential matches in the same sentence
-            // by only taking the first occurrence.
-            characterStart = text.indexOf(spanText, characterStart);
-            final int characterEnd = characterStart + spanText.length();
+                  spans.add(new Span(characterStart, characterEnd, spanText, 
confidence));
 
-            spans.add(new Span(characterStart, characterEnd, spanText, 
confidence));
+                  // OP-1: Only increment characterStart by one.
+                  characterStart++;
 
-            characterStart = characterEnd;
+                }
+
+              }
+
+            }
 
           }
 
+        } catch (OrtException ex) {
+          throw new RuntimeException("Error performing namefinder inference: " 
+ ex.getMessage(), ex);
         }
 
-      } catch (OrtException ex) {
-        throw new RuntimeException("Error performing namefinder inference: " + 
ex.getMessage(), ex);
       }
 
     }
diff --git a/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
index 26beab04..14e4c67b 100644
--- a/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
+++ b/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
@@ -18,6 +18,7 @@
 package opennlp.dl.namefinder;
 
 import java.io.File;
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -27,10 +28,18 @@ import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import opennlp.dl.AbstactDLTest;
+import opennlp.tools.sentdetect.SentenceDetector;
+import opennlp.tools.sentdetect.SentenceDetectorME;
 import opennlp.tools.util.Span;
 
 public class NameFinderDLEval extends AbstactDLTest {
 
+  private final SentenceDetector sentenceDetector ;
+
+  public NameFinderDLEval() throws IOException {
+    this.sentenceDetector = new SentenceDetectorME("en");
+  }
+
   @Test
   public void tokenNameFinder1Test() throws Exception {
 
@@ -43,7 +52,7 @@ public class NameFinderDLEval extends AbstactDLTest {
     final String[] tokens = new String[]
         {"George", "Washington", "was", "president", "of", "the", "United", 
"States", "."};
 
-    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels());
+    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels(), sentenceDetector);
     final Span[] spans = nameFinderDL.find(tokens);
 
     for (Span span : spans) {
@@ -69,7 +78,7 @@ public class NameFinderDLEval extends AbstactDLTest {
 
     final String[] tokens = new String[]{"His", "name", "was", "George", 
"Washington"};
 
-    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels());
+    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels(), sentenceDetector);
     final Span[] spans = nameFinderDL.find(tokens);
 
     for (Span span : spans) {
@@ -93,7 +102,7 @@ public class NameFinderDLEval extends AbstactDLTest {
 
     final String[] tokens = new String[]{"His", "name", "was", "George"};
 
-    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels());
+    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels(), sentenceDetector);
     final Span[] spans = nameFinderDL.find(tokens);
 
     for (Span span : spans) {
@@ -117,7 +126,7 @@ public class NameFinderDLEval extends AbstactDLTest {
 
     final String[] tokens = new String[]{};
 
-    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels());
+    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels(), sentenceDetector);
     final Span[] spans = nameFinderDL.find(tokens);
 
     Assertions.assertEquals(0, spans.length);
@@ -135,7 +144,7 @@ public class NameFinderDLEval extends AbstactDLTest {
 
     final String[] tokens = new String[]{"I", "went", "to", "the", "park"};
 
-    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels());
+    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels(), sentenceDetector);
     final Span[] spans = nameFinderDL.find(tokens);
 
     Assertions.assertEquals(0, spans.length);
@@ -154,7 +163,7 @@ public class NameFinderDLEval extends AbstactDLTest {
     final String[] tokens = new String[]{"George", "Washington", "and", 
"Abraham", "Lincoln",
         "were", "presidents"};
 
-    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels());
+    final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, 
getIds2Labels(), sentenceDetector);
     final Span[] spans = nameFinderDL.find(tokens);
 
     for (Span span : spans) {
@@ -179,7 +188,7 @@ public class NameFinderDLEval extends AbstactDLTest {
       final File model = new File("invalid.onnx");
       final File vocab = new File("vocab.txt");
 
-      new NameFinderDL(model, vocab, getIds2Labels());
+      new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector);
     });
 
   }

Reply via email to