krickert commented on code in PR #1086:
URL: https://github.com/apache/opennlp/pull/1086#discussion_r3434441240


##########
opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java:
##########
@@ -144,246 +152,334 @@ private static InferenceOptions 
validateConstructorArguments(
   @Override
   public Span[] find(String[] input) {
 
-    final List<Span> spans = new LinkedList<>();
+    final List<Span> spans = new ArrayList<>();
 
     // Join the tokens here because they will be tokenized using Wordpiece 
during inference.
     final String text = String.join(" ", input);
 
-    final String[] sentences = sentenceDetector.sentDetect(text);
+    // sentPosDetect (not sentDetect) so each sentence's offset in the full 
text is known.
+    final Span[] sentenceSpans = sentenceDetector.sentPosDetect(text);
+
+    for (final Span sentenceSpan : sentenceSpans) {
 
-    for (String sentence : sentences) {
+      // Floor the character cursor at this sentence's start, then thread it 
forward across the
+      // sentence's chunks so a repeated surface form is located at its next 
occurrence. Flooring
+      // per sentence keeps an entity from being matched against an identical 
surface form in an
+      // earlier sentence -- even one that produced no spans, which would 
otherwise leave the
+      // cursor behind and mis-locate the match.
+      int searchStart = sentenceSpan.getStart();
 
       // The WordPiece tokenized text. This changes the spacing in the text.
-      final List<Tokens> wordpieceTokens = tokenize(sentence);
+      final List<Tokens> wordpieceTokens = 
tokenize(sentenceSpan.getCoveredText(text).toString());
 
       for (final Tokens tokens : wordpieceTokens) {
+        final List<Span> decoded =
+            decodeSpans(text, tokens.tokens(), infer(tokens), ids2Labels, 
searchStart);
+        spans.addAll(decoded);
+        if (!decoded.isEmpty()) {
+          searchStart = decoded.get(decoded.size() - 1).getEnd();
+        }
+      }
 
-        try {
-
-          // The inputs to the ONNX model.
-          final Map<String, OnnxTensor> inputs = new HashMap<>();
-
-          final float[][][] v;
-          try {
-            inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, 
LongBuffer.wrap(tokens.ids()),
-                new long[] {1, tokens.ids().length}));
-
-            if (includeAttentionMask) {
-              inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
-                  LongBuffer.wrap(tokens.mask()), new long[] {1, 
tokens.mask().length}));
-            }
-
-            if (includeTokenTypeIds) {
-              inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
-                  LongBuffer.wrap(tokens.types()), new long[] {1, 
tokens.types().length}));
-            }
-
-            // The outputs from the model.
-            try (OrtSession.Result result = session.run(inputs)) {
-              // getValue() copies the tensor into Java arrays, so the result 
can be closed safely.
-              v = (float[][][]) result.get(0).getValue();
-            }
-          } finally {
-            inputs.values().forEach(OnnxTensor::close);
-          }
-
-          // Find consecutive B-PER and I-PER labels and combine the spans 
where necessary.
-          // There are also B-LOC and I-LOC tags for locations that might be 
useful at some point.
+    }
 
-          // Keep track of where the last span was so when there are 
multiple/duplicate
-          // spans we can get the next one instead of the first one each time.
-          int characterStart = 0;
+    return spans.toArray(new Span[0]);
 
-          final String[] toks = tokens.tokens();
+  }
 
-          // We are looping over the vector for each word,
-          // finding the index of the array that has the maximum value,
-          // and then finding the token classification that corresponds to 
that index.
-          for (int x = 0; x < v[0].length; x++) {
+  /**
+   * Runs the model on one token window and returns the per-token label score 
rows. A failure
+   * executing the model (an {@link OrtException} or any runtime fault) is 
surfaced as an
+   * {@link IllegalStateException} (cause preserved); an unexpected output 
shape is its own loud
+   * failure. This mirrors the fail-loud contract of the sibling {@code 
DocumentCategorizerDL}.
+   *
+   * @param tokens The tokens for one chunk to run inference on.
+   * @return The {@code [token][label]} score matrix for the chunk.
+   */
+  private float[][] infer(final Tokens tokens) {
 
-            final float[] arr = v[0][x];
-            final int maxIndex = maxIndex(arr);
-            final String label = ids2Labels.get(maxIndex);
+    final Map<String, OnnxTensor> inputs = new HashMap<>();
+    final Object output;
+    try {
+      inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, 
LongBuffer.wrap(tokens.ids()),
+          new long[] {1, tokens.ids().length}));
 
-            // TODO: Need to make sure this value is between 0 and 1?
-            // Can we do thresholding without it between 0 and 1?
-            final double confidence = arr[maxIndex]; // / 10;
+      if (includeAttentionMask) {
+        inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+            LongBuffer.wrap(tokens.mask()), new long[] {1, 
tokens.mask().length}));
+      }
 
-            // Is this is the start of a person entity.
-            if (B_PER.equals(label)) {
+      if (includeTokenTypeIds) {
+        inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+            LongBuffer.wrap(tokens.types()), new long[] {1, 
tokens.types().length}));
+      }
 
-              String spanText;
+      // getValue() copies the tensor into Java arrays, so the result can be 
closed safely.
+      try (OrtSession.Result result = session.run(inputs)) {
+        output = result.get(0).getValue();
+      }
+    } catch (OrtException | RuntimeException ex) {
+      throw new IllegalStateException("Unable to perform name finder 
inference", ex);

Review Comment:
   Done.

   - `OrtException` and `RuntimeException` are now caught separately, with
     distinct messages (an ONNX Runtime failure vs. an unexpected runtime
     fault). Each is wrapped in an `IllegalStateException` whose message
     includes the underlying cause, and the original exception is preserved
     as the chained cause.
   - `find()` now has Javadoc (using `{@inheritDoc}`) that documents both
     exceptions it can surface:
     - `IllegalStateException`: an inference failure, an unexpected output
       shape, or an unmapped label index.
     - `IllegalArgumentException`: a token absent from the vocabulary
       (i.e. a vocabulary/model mismatch).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to