kinow commented on code in PR #422:
URL: https://github.com/apache/opennlp/pull/422#discussion_r941276459


##########
opennlp-dl/src/main/java/opennlp/dl/doccat/scoring/AverageClassifcationScoringStrategy.java:
##########
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.doccat.scoring;
+
+import java.util.List;
+
+/**
+ * Calculates the document classification scores by averaging the scores for
+ * all individual parts of a document.
+ */
+public class AverageClassifcationScoringStrategy implements 
ClassificationScoringStrategy {
+
+  @Override
+  public double[] score(List<double[]> scores) {
+
+    final int values = scores.get(0).length;
+
+    final double[] averages = new double[values];
+
+    int j = 0;
+
+    for (int i = 0; i < values; i++) {
+
+      double sum = 0;
+
+      for (final double[] score : scores) {
+
+        sum += score[i];
+
+      }
+
+      averages[j++] = (sum / scores.size());
+
+    }
+
+    return averages;

Review Comment:
   For a second I thought this was simply calculating the average of arrays, 
and was going to suggest using streams and other Java 8+ methods… then I 
realized it was doing something more elaborate. :+1: No need to change anything 
here then! :+1: 



##########
opennlp-dl/src/main/java/opennlp/dl/doccat/scoring/AverageClassifcationScoringStrategy.java:
##########
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.doccat.scoring;
+
+import java.util.List;
+
+/**
+ * Calculates the document classification scores by averaging the scores for
+ * all individual parts of a document.
+ */
+public class AverageClassifcationScoringStrategy implements 
ClassificationScoringStrategy {

Review Comment:
   Typo in class name? "Classifcation" is missing an "i":
   
   s/AverageClassifcationScoringStrategy/AverageClassificationScoringStrategy



##########
opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java:
##########
@@ -17,140 +17,214 @@
 
 package opennlp.dl.namefinder;
 
+import java.io.BufferedReader;
 import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.nio.LongBuffer;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
 
-import opennlp.dl.Inference;
 import opennlp.dl.InferenceOptions;
+import opennlp.dl.SpanEnd;
+import opennlp.dl.Tokens;
 import opennlp.tools.namefind.TokenNameFinder;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WordpieceTokenizer;
 import opennlp.tools.util.Span;
 
 /**
  * An implementation of {@link TokenNameFinder} that uses ONNX models.
  */
 public class NameFinderDL implements TokenNameFinder {
 
+  public static final String INPUT_IDS = "input_ids";
+  public static final String ATTENTION_MASK = "attention_mask";
+  public static final String TOKEN_TYPE_IDS = "token_type_ids";
+
   public static final String I_PER = "I-PER";
   public static final String B_PER = "B-PER";
 
-  private final Inference inference;
+  protected final OrtSession session;
+
   private final Map<Integer, String> ids2Labels;
+  private final Tokenizer tokenizer;
+  private final Map<String, Integer> vocab;
+  private final InferenceOptions inferenceOptions;
+  protected final OrtEnvironment env;
 
-  /**
-   * Creates a new NameFinderDL for entity recognition using ONNX models.
-   *
-   * @param model     The ONNX model file.
-   * @param vocab     The model's vocabulary file.
-   * @param ids2Labels  A map of values and their assigned labels used to 
train the model.
-   * @throws Exception Thrown if the models cannot be loaded.
-   */
-  public NameFinderDL(File model, File vocab, Map<Integer, String> ids2Labels)
-          throws Exception {
+  public NameFinderDL(File model, File vocabulary, Map<Integer, String> 
ids2Labels) throws Exception {
 
-    this.ids2Labels = ids2Labels;
-    this.inference = new NameFinderInference(model, vocab, new 
InferenceOptions());
+    this(model, vocabulary, ids2Labels, new InferenceOptions());
 
   }
 
-  /**
-   * Creates a new NameFinderDL for entity recognition using ONNX models.
-   *
-   * @param model     The ONNX model file.
-   * @param vocab     The model's vocabulary file.
-   * @param ids2Labels  A map of values and their assigned labels used to 
train the model.
-   * @param inferenceOptions The {@link InferenceOptions} used to customize 
the inference process.
-   * @throws Exception Thrown if the models cannot be loaded.
-   */
-  public NameFinderDL(File model, File vocab, Map<Integer, String> ids2Labels,
+  public NameFinderDL(File model, File vocabulary, Map<Integer, String> 
ids2Labels,
                       InferenceOptions inferenceOptions) throws Exception {
 
-    this.ids2Labels = ids2Labels;
-    this.inference = new NameFinderInference(model, vocab, inferenceOptions);
-
-  }
+    this.env = OrtEnvironment.getEnvironment();
 
-  /**
-   * Creates a new NameFinderDL for entity recognition using ONNX models.
-   *
-   * @param ids2Labels  A map of values and their assigned labels used to 
train the model.
-   * @param inference A custom implementation of {@link Inference}.
-   * @throws Exception Thrown if the models cannot be loaded.
-   */
-  public NameFinderDL(Map<Integer, String> ids2Labels,
-                      Inference inference) throws Exception {
+    final OrtSession.SessionOptions sessionOptions = new 
OrtSession.SessionOptions();
+    if (inferenceOptions.isGpu()) {
+      sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
+    }
 
+    this.session = env.createSession(model.getPath(), sessionOptions);
     this.ids2Labels = ids2Labels;
-    this.inference = inference;
+    this.vocab = loadVocab(vocabulary);
+    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.inferenceOptions = inferenceOptions;
 
   }
 
   @Override
-  public Span[] find(String[] tokens) {
+  public Span[] find(String[] input) {
+
+
+
+    /**
+     * So, it looks like inference is being done on the wordpiece tokens but 
then
+     * spans are being created from the whitespace tokens.
+     */
 
     final List<Span> spans = new LinkedList<>();
-    final String text = String.join(" ", tokens);
 
-    try {
+    // Join the tokens here because they will be tokenized using Wordpiece 
during inference.
+    final String text = String.join(" ", input);
+
+    // The WordPiece tokenized text. This changes the spacing in the text.
+    final List<Tokens> wordpieceTokens = tokenize(text);
+
+    for (final Tokens tokens : wordpieceTokens) {
+
+      try {
+
+        // The inputs to the ONNX model.
+        final Map<String, OnnxTensor> inputs = new HashMap<>();
+        inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, 
LongBuffer.wrap(tokens.getIds()),
+            new long[] {1, tokens.getIds().length}));
+
+        if (inferenceOptions.isIncludeAttentionMask()) {
+          inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+              LongBuffer.wrap(tokens.getMask()), new long[] {1, 
tokens.getMask().length}));
+        }
+
+        if (inferenceOptions.isIncludeTokenTypeIds()) {
+          inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+              LongBuffer.wrap(tokens.getTypes()), new long[] {1, 
tokens.getTypes().length}));
+        }
+
+        // The outputs from the model.
+        final float[][][] v = (float[][][]) 
session.run(inputs).get(0).getValue();
+
+        // Find consecutive B-PER and I-PER labels and combine the spans where 
necessary.
+        // There are also B-LOC and I-LOC tags for locations that might be 
useful at some point.
+
+        // Keep track of where the last span was so when there are 
multiple/duplicate
+        // spans we can get the next one instead of the first one each time.
+        int characterStart = 0;
+
+        // We are looping over the vector for each word,
+        // finding the index of the array that has the maximum value,
+        // and then finding the token classification that corresponds to that 
index.
+        for (int x = 0; x < v[0].length; x++) {
+
+          final float[] arr = v[0][x];
+          final int maxIndex = maxIndex(arr);
+          final String label = ids2Labels.get(maxIndex);
+
+          // TODO: Need to make sure this value is between 0 and 1?
+          // Can we do thresholding without it between 0 and 1?
+          final double confidence = arr[maxIndex]; // / 10;
+
+          // Show each token and its label per the model.
+          // System.out.println(tokens.getTokens()[x] + " : " + label);
+
+          // Is this is the start of a person entity.
+          if (B_PER.equals(label)) {
+
+            final String spanText;
+
+            // Find the end index of the span in the array (where the label is 
not I-PER).
+            final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels, 
tokens.getTokens());
+
+            // If the end is -1 it means this is a single-span token.
+            // If the end is != -1 it means this is a multi-span token.
+            if (spanEnd.getIndex() != -1) {
 
-      final float[][][] vectors = (float[][][]) inference.infer(text);
-      final double[][] v = inference.convertFloatsToDoubles(vectors[0]);
+              final StringBuilder sb = new StringBuilder();
 
-      // Find consecutive B-PER and I-PER labels and combine the spans where 
necessary.
-      // There are also B-LOC and I-LOC tags for locations that might be 
useful at some point.
+              // We have to concatenate the tokens.
+              // Add each token in the array and separate them with a space.
+              // We'll separate each with a single space because later we'll 
find the original span
+              // in the text and ignore spacing between individual tokens in 
findByRegex().
+              int end = spanEnd.getIndex();
+              for (int i = x; i <= end; i++) {
 
-      // Keep track of where the last span was so when there are 
multiple/duplicate
-      // spans we can get the next one instead of the first one each time.
-      int characterStart = 0;
+                // If the next token starts with ##, combine it with this 
token.
+                if (tokens.getTokens()[i + 1].startsWith("##")) {
 
-      // We are looping over the vector for each word,
-      // finding the index of the array that has the maximum value,
-      // and then finding the token classification that corresponds to that 
index.
-      for (int x = 0; x < v.length; x++) {
+                  sb.append(tokens.getTokens()[i] + tokens.getTokens()[i + 
1].replaceAll("##", ""));
 
-        final double[] arr = v[x];
-        final int maxIndex = Inference.maxIndex(arr);
-        final String label = ids2Labels.get(maxIndex);
+                  // Append a space unless the next (next) token starts with 
##.
+                  if (!tokens.getTokens()[i + 2].startsWith("##")) {
+                    sb.append(" ");
+                  }
 
-        final double probability = arr[maxIndex];
+                  // Skip the next token since we just included it in this 
iteration.
+                  i++;
 
-        if (B_PER.equalsIgnoreCase(label)) {
+                } else {
 
-          // This is the start of a person entity.
-          final String spanText;
+                  sb.append(tokens.getTokens()[i].replaceAll("##", ""));
 
-          // Find the end index of the span in the array (where the label is 
not I-PER).
-          final int endIndex = findSpanEnd(v, x, ids2Labels);
+                  // Append a space unless the next token is a period.
+                  if (!".".equals(tokens.getTokens()[i + 1])) {
+                    sb.append(" ");
+                  }
 
-          // If the end is -1 it means this is a single-span token.
-          // If the end is != -1 it means this is a multi-span token.
-          if (endIndex != -1) {
+                }
 
-            // Subtract one for the beginning token not part of the text.
-            spanText = String.join(" ", Arrays.copyOfRange(tokens, x - 1, 
endIndex));
+              }
 
-            spans.add(new Span(x - 1, endIndex, spanText, probability));
+              // This is the text of the span. We use the whole original input 
text and not one
+              // of the splits. This gives us accurate character positions.
+              spanText = findByRegex(text, sb.toString().trim()).trim();
 
-            x = endIndex;
+            } else {
 
-          } else {
+              // This is a single-token span so there is nothing else to do 
except grab the token.
+              spanText = tokens.getTokens()[x];
 
-            // This is a single-token span so there is nothing else to do 
except grab the token.
-            spanText = tokens[x];
+            }
 
-            // Subtract one for the beginning token not part of the text.
-            spans.add(new Span(x - 1, endIndex, spanText, probability));
+            // This ignores other potential matches in the same sentence
+            // by only taking the first occurrence.
+            characterStart = text.indexOf(spanText, characterStart);
+            final int characterEnd = characterStart + spanText.length();
+
+            spans.add(new Span(characterStart, characterEnd, spanText, 
confidence));
+
+            characterStart = characterEnd;
 
           }
 
         }
 
+      } catch (OrtException ex) {
+        throw new RuntimeException("Error performing namefinder inference: " + 
ex.getMessage(), ex);

Review Comment:
   :+1: Better than the old `System.err` output, which dropped the rest of the stack trace.



##########
opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java:
##########
@@ -176,20 +251,147 @@ private int findSpanEnd(double[][] v, int startIndex, 
Map<Integer, String> id2La
     for (int x = startIndex + 1; x < v[0].length; x++) {
 
       // Get the next item.
-      final double[] arr = v[x];
+      final float[] arr = v[0][x];
 
       // See if the next token has an I-PER label.
-      final String nextTokenClassification = 
id2Labels.get(Inference.maxIndex(arr));
+      final String nextTokenClassification = id2Labels.get(maxIndex(arr));
 
-      if (!I_PER.equalsIgnoreCase(nextTokenClassification)) {
+      if (!I_PER.equals(nextTokenClassification)) {
         index = x - 1;
         break;
       }
 
     }
 
+    // Find where the span ends based on the tokens.
+    for (int x = 1; x <= index && x < tokens.length; x++) {
+      characterEnd += tokens[x].length();
+    }
+
+    // Account for the number of spaces (that is the number of tokens).
+    // (One space per token.)
+    characterEnd += index - 1;
+
+    return new SpanEnd(index, characterEnd);
+
+  }
+
+  private int maxIndex(float[] arr) {
+
+    double max = Float.NEGATIVE_INFINITY;
+    int index = -1;
+
+    for (int x = 0; x < arr.length; x++) {
+      if (arr[x] > max) {
+        index = x;
+        max = arr[x];
+      }
+    }
+
     return index;
 
   }
 
+  private static String findByRegex(String text, String span) {
+
+    final String regex = span
+        .replaceAll(" ", "\\\\s+")
+        .replaceAll("\\)", "\\\\)")
+        .replaceAll("\\(", "\\\\(");
+
+    final Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
+    final Matcher matcher = pattern.matcher(text);
+
+    if (matcher.find()) {
+      return matcher.group(0);
+    }
+
+    // For some reason the regex match wasn't found. Just return the original 
span.
+    return span;
+
+  }
+
+  private List<Tokens> tokenize(final String text) {
+
+    final List<Tokens> t = new LinkedList<>();
+
+    // In this article as the paper suggests, we are going to segment the 
input into smaller text and feed
+    // each of them into BERT, it means for each row, we will split the text 
in order to have some
+    // smaller text (200 words long each)
+    // 
https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd

Review Comment:
   :eyes: :+1: 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to