This is an automated email from the ASF dual-hosted git repository.
jzemerick pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/master by this push:
new 02eaffa4 OPENNLP-1407: Adding sentence detector to NameFinderDL. (#447)
02eaffa4 is described below
commit 02eaffa4cfb0f9c2c2ea1da864f8cc53d215e97d
Author: Jeff Zemerick <[email protected]>
AuthorDate: Sat Dec 10 09:18:17 2022 -0500
OPENNLP-1407: Adding sentence detector to NameFinderDL. (#447)
---
.../java/opennlp/dl/namefinder/NameFinderDL.java | 194 +++++++++++----------
.../opennlp/dl/namefinder/NameFinderDLEval.java | 23 ++-
2 files changed, 121 insertions(+), 96 deletions(-)
diff --git a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index ff132897..319a2074 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -39,6 +39,7 @@ import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
+import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;
@@ -54,23 +55,27 @@ public class NameFinderDL implements TokenNameFinder {
public static final String I_PER = "I-PER";
public static final String B_PER = "B-PER";
+ public static final String SEPARATOR = "[SEP]";
protected final OrtSession session;
+ private final SentenceDetector sentenceDetector;
private final Map<Integer, String> ids2Labels;
private final Tokenizer tokenizer;
private final Map<String, Integer> vocab;
private final InferenceOptions inferenceOptions;
protected final OrtEnvironment env;
- public NameFinderDL(File model, File vocabulary, Map<Integer, String>
ids2Labels) throws Exception {
+ public NameFinderDL(File model, File vocabulary, Map<Integer, String>
ids2Labels,
+ SentenceDetector sentenceDetector) throws Exception {
- this(model, vocabulary, ids2Labels, new InferenceOptions());
+ this(model, vocabulary, ids2Labels, new InferenceOptions(),
sentenceDetector);
}
public NameFinderDL(File model, File vocabulary, Map<Integer, String>
ids2Labels,
- InferenceOptions inferenceOptions) throws Exception {
+ InferenceOptions inferenceOptions,
+ SentenceDetector sentenceDetector) throws Exception {
this.env = OrtEnvironment.getEnvironment();
@@ -84,145 +89,156 @@ public class NameFinderDL implements TokenNameFinder {
this.vocab = loadVocab(vocabulary);
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.inferenceOptions = inferenceOptions;
+ this.sentenceDetector = sentenceDetector;
}
@Override
public Span[] find(String[] input) {
-
-
- /**
- * So, it looks like inference is being done on the wordpiece tokens but
then
- * spans are being created from the whitespace tokens.
- */
-
final List<Span> spans = new LinkedList<>();
// Join the tokens here because they will be tokenized using Wordpiece
during inference.
final String text = String.join(" ", input);
- // The WordPiece tokenized text. This changes the spacing in the text.
- final List<Tokens> wordpieceTokens = tokenize(text);
+ final String[] sentences = sentenceDetector.sentDetect(text);
- for (final Tokens tokens : wordpieceTokens) {
+ for (String sentence : sentences) {
- try {
+ // The WordPiece tokenized text. This changes the spacing in the text.
+ final List<Tokens> wordpieceTokens = tokenize(sentence);
- // The inputs to the ONNX model.
- final Map<String, OnnxTensor> inputs = new HashMap<>();
- inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.getIds()),
- new long[] {1, tokens.getIds().length}));
+ for (final Tokens tokens : wordpieceTokens) {
- if (inferenceOptions.isIncludeAttentionMask()) {
- inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
- LongBuffer.wrap(tokens.getMask()), new long[] {1,
tokens.getMask().length}));
- }
+ try {
- if (inferenceOptions.isIncludeTokenTypeIds()) {
- inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
- LongBuffer.wrap(tokens.getTypes()), new long[] {1,
tokens.getTypes().length}));
- }
+ // The inputs to the ONNX model.
+ final Map<String, OnnxTensor> inputs = new HashMap<>();
+ inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.getIds()),
+ new long[] {1, tokens.getIds().length}));
- // The outputs from the model.
- final float[][][] v = (float[][][])
session.run(inputs).get(0).getValue();
+ if (inferenceOptions.isIncludeAttentionMask()) {
+ inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(tokens.getMask()), new long[] {1,
tokens.getMask().length}));
+ }
- // Find consecutive B-PER and I-PER labels and combine the spans where
necessary.
- // There are also B-LOC and I-LOC tags for locations that might be
useful at some point.
+ if (inferenceOptions.isIncludeTokenTypeIds()) {
+ inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(tokens.getTypes()), new long[] {1,
tokens.getTypes().length}));
+ }
- // Keep track of where the last span was so when there are
multiple/duplicate
- // spans we can get the next one instead of the first one each time.
- int characterStart = 0;
+ // The outputs from the model.
+ final float[][][] v = (float[][][])
session.run(inputs).get(0).getValue();
- // We are looping over the vector for each word,
- // finding the index of the array that has the maximum value,
- // and then finding the token classification that corresponds to that
index.
- for (int x = 0; x < v[0].length; x++) {
+ // Find consecutive B-PER and I-PER labels and combine the spans
where necessary.
+ // There are also B-LOC and I-LOC tags for locations that might be
useful at some point.
- final float[] arr = v[0][x];
- final int maxIndex = maxIndex(arr);
- final String label = ids2Labels.get(maxIndex);
+ // Keep track of where the last span was so when there are
multiple/duplicate
+ // spans we can get the next one instead of the first one each time.
+ int characterStart = 0;
- // TODO: Need to make sure this value is between 0 and 1?
- // Can we do thresholding without it between 0 and 1?
- final double confidence = arr[maxIndex]; // / 10;
+ // We are looping over the vector for each word,
+ // finding the index of the array that has the maximum value,
+ // and then finding the token classification that corresponds to
that index.
+ for (int x = 0; x < v[0].length; x++) {
- // Show each token and its label per the model.
- // System.out.println(tokens.getTokens()[x] + " : " + label);
+ final float[] arr = v[0][x];
+ final int maxIndex = maxIndex(arr);
+ final String label = ids2Labels.get(maxIndex);
- // Is this is the start of a person entity.
- if (B_PER.equals(label)) {
+ // TODO: Need to make sure this value is between 0 and 1?
+ // Can we do thresholding without it between 0 and 1?
+ final double confidence = arr[maxIndex]; // / 10;
- final String spanText;
+ // Is this is the start of a person entity.
+ if (B_PER.equals(label)) {
- // Find the end index of the span in the array (where the label is
not I-PER).
- final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels,
tokens.getTokens());
+ String spanText;
- // If the end is -1 it means this is a single-span token.
- // If the end is != -1 it means this is a multi-span token.
- if (spanEnd.getIndex() != -1) {
+ // Find the end index of the span in the array (where the label
is not I-PER).
+ final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels,
tokens.getTokens());
- final StringBuilder sb = new StringBuilder();
+ // If the end is -1 it means this is a single-span token.
+ // If the end is != -1 it means this is a multi-span token.
+ if (spanEnd.getIndex() != -1) {
- // We have to concatenate the tokens.
- // Add each token in the array and separate them with a space.
- // We'll separate each with a single space because later we'll
find the original span
- // in the text and ignore spacing between individual tokens in
findByRegex().
- int end = spanEnd.getIndex();
- for (int i = x; i <= end; i++) {
+ final StringBuilder sb = new StringBuilder();
- // If the next token starts with ##, combine it with this
token.
- if (tokens.getTokens()[i + 1].startsWith("##")) {
+ // We have to concatenate the tokens.
+ // Add each token in the array and separate them with a space.
+ // We'll separate each with a single space because later we'll
find the original span
+ // in the text and ignore spacing between individual tokens in
findByRegex().
+ int end = spanEnd.getIndex();
+ for (int i = x; i <= end; i++) {
- sb.append(tokens.getTokens()[i] + tokens.getTokens()[i +
1].replaceAll("##", ""));
+ // If the next token starts with ##, combine it with this
token.
+ if (tokens.getTokens()[i + 1].startsWith("##")) {
- // Append a space unless the next (next) token starts with
##.
- if (!tokens.getTokens()[i + 2].startsWith("##")) {
- sb.append(" ");
- }
+ sb.append(tokens.getTokens()[i] + tokens.getTokens()[i +
1].replaceAll("##", ""));
+
+ // Append a space unless the next (next) token starts with
##.
+ if (!tokens.getTokens()[i + 2].startsWith("##")) {
+ sb.append(" ");
+ }
+
+ // Skip the next token since we just included it in this
iteration.
+ i++;
- // Skip the next token since we just included it in this
iteration.
- i++;
+ } else {
- } else {
+ sb.append(tokens.getTokens()[i].replaceAll("##", ""));
- sb.append(tokens.getTokens()[i].replaceAll("##", ""));
+ // Append a space unless the next token is a period.
+ if (!".".equals(tokens.getTokens()[i + 1])) {
+ sb.append(" ");
+ }
- // Append a space unless the next token is a period.
- if (!".".equals(tokens.getTokens()[i + 1])) {
- sb.append(" ");
}
}
+ // This is the text of the span. We use the whole original
input text and not one
+ // of the splits. This gives us accurate character positions.
+ spanText = findByRegex(text, sb.toString().trim()).trim();
+
+ } else {
+
+ // This is a single-token span so there is nothing else to do
except grab the token.
+ spanText = tokens.getTokens()[x];
+
}
- // This is the text of the span. We use the whole original input
text and not one
- // of the splits. This gives us accurate character positions.
- spanText = findByRegex(text, sb.toString().trim()).trim();
+ if (!SEPARATOR.equals(spanText)) {
- } else {
+ spanText = spanText.replaceAll("##", "");
- // This is a single-token span so there is nothing else to do
except grab the token.
- spanText = tokens.getTokens()[x];
+ // This ignores other potential matches in the same sentence
+ // by only taking the first occurrence.
+ characterStart = text.indexOf(spanText, characterStart);
- }
+ // TODO: This check should not be needed because the span was
found.
+ // If we aren't finding it now it's because there's a
whitespace difference.
+ if (characterStart != -1) {
+
+ final int characterEnd = characterStart + spanText.length();
- // This ignores other potential matches in the same sentence
- // by only taking the first occurrence.
- characterStart = text.indexOf(spanText, characterStart);
- final int characterEnd = characterStart + spanText.length();
+ spans.add(new Span(characterStart, characterEnd, spanText,
confidence));
- spans.add(new Span(characterStart, characterEnd, spanText,
confidence));
+ // OP-1: Only increment characterStart by one.
+ characterStart++;
- characterStart = characterEnd;
+ }
+
+ }
+
+ }
}
+ } catch (OrtException ex) {
+ throw new RuntimeException("Error performing namefinder inference: " + ex.getMessage(), ex);
}
- } catch (OrtException ex) {
- throw new RuntimeException("Error performing namefinder inference: " +
ex.getMessage(), ex);
}
}
diff --git a/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
index 26beab04..14e4c67b 100644
--- a/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
+++ b/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
@@ -18,6 +18,7 @@
package opennlp.dl.namefinder;
import java.io.File;
+import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
@@ -27,10 +28,18 @@ import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import opennlp.dl.AbstactDLTest;
+import opennlp.tools.sentdetect.SentenceDetector;
+import opennlp.tools.sentdetect.SentenceDetectorME;
import opennlp.tools.util.Span;
public class NameFinderDLEval extends AbstactDLTest {
+ private final SentenceDetector sentenceDetector ;
+
+ public NameFinderDLEval() throws IOException {
+ this.sentenceDetector = new SentenceDetectorME("en");
+ }
+
@Test
public void tokenNameFinder1Test() throws Exception {
@@ -43,7 +52,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final String[] tokens = new String[]
{"George", "Washington", "was", "president", "of", "the", "United",
"States", "."};
- final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels());
+ final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels(), sentenceDetector);
final Span[] spans = nameFinderDL.find(tokens);
for (Span span : spans) {
@@ -69,7 +78,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final String[] tokens = new String[]{"His", "name", "was", "George",
"Washington"};
- final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels());
+ final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels(), sentenceDetector);
final Span[] spans = nameFinderDL.find(tokens);
for (Span span : spans) {
@@ -93,7 +102,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final String[] tokens = new String[]{"His", "name", "was", "George"};
- final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels());
+ final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels(), sentenceDetector);
final Span[] spans = nameFinderDL.find(tokens);
for (Span span : spans) {
@@ -117,7 +126,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final String[] tokens = new String[]{};
- final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels());
+ final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels(), sentenceDetector);
final Span[] spans = nameFinderDL.find(tokens);
Assertions.assertEquals(0, spans.length);
@@ -135,7 +144,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final String[] tokens = new String[]{"I", "went", "to", "the", "park"};
- final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels());
+ final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels(), sentenceDetector);
final Span[] spans = nameFinderDL.find(tokens);
Assertions.assertEquals(0, spans.length);
@@ -154,7 +163,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final String[] tokens = new String[]{"George", "Washington", "and",
"Abraham", "Lincoln",
"were", "presidents"};
- final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels());
+ final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab,
getIds2Labels(), sentenceDetector);
final Span[] spans = nameFinderDL.find(tokens);
for (Span span : spans) {
@@ -179,7 +188,7 @@ public class NameFinderDLEval extends AbstactDLTest {
final File model = new File("invalid.onnx");
final File vocab = new File("vocab.txt");
- new NameFinderDL(model, vocab, getIds2Labels());
+ new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector);
});
}