This is an automated email from the ASF dual-hosted git repository.

mawiesne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/main by this push:
     new 889ceab12 OPENNLP-1845: Fix numerically unstable softmax in DocumentCategorizerDL (#1085)
889ceab12 is described below

commit 889ceab12a932d26214cbec524d684589a370ca1
Author: Kristian Rickert <[email protected]>
AuthorDate: Tue Jun 16 05:34:31 2026 -0400

    OPENNLP-1845: Fix numerically unstable softmax in DocumentCategorizerDL (#1085)
    
    * Fix DocumentCategorizerDL softmax and error result handling
    
    * Fail loudly instead of 0 result
---
 .../opennlp/dl/doccat/DocumentCategorizerDL.java   | 190 ++++++++++++---------
 .../dl/doccat/DocumentCategorizerDLTest.java       | 125 ++++++++++++++
 .../dl/doccat/DocumentCategorizerDLEval.java       |  37 ++++
 3 files changed, 276 insertions(+), 76 deletions(-)

diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index cf01631bf..e357c48f2 100644
--- 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 import java.nio.LongBuffer;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -37,8 +38,6 @@ import ai.onnxruntime.OnnxTensor;
 import ai.onnxruntime.OrtEnvironment;
 import ai.onnxruntime.OrtException;
 import ai.onnxruntime.OrtSession;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import opennlp.dl.AbstractDL;
 import opennlp.dl.InferenceOptions;
@@ -63,8 +62,6 @@ import opennlp.tools.doccat.DocumentCategorizer;
  */
 public class DocumentCategorizerDL extends AbstractDL implements 
DocumentCategorizer {
 
-  private static final Logger logger = 
LoggerFactory.getLogger(DocumentCategorizerDL.class);
-
   /** Classification models are commonly uncased, so lower casing is the 
default. */
   private static final boolean LOWER_CASE_DEFAULT = true;
 
@@ -72,6 +69,19 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
   private final ClassificationScoringStrategy classificationScoringStrategy;
   private final InferenceOptions inferenceOptions;
 
+  DocumentCategorizerDL(OrtEnvironment env, OrtSession session, Map<String, 
Integer> vocab,
+                        Map<Integer, String> categories,
+                        ClassificationScoringStrategy 
classificationScoringStrategy,
+                        InferenceOptions inferenceOptions) {
+    this.env = env;
+    this.session = session;
+    this.vocab = vocab;
+    this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions, 
LOWER_CASE_DEFAULT));
+    this.categories = categories;
+    this.classificationScoringStrategy = classificationScoringStrategy;
+    this.inferenceOptions = inferenceOptions;
+  }
+
   /**
    * Instantiates a {@link DocumentCategorizer document categorizer} using 
ONNX models.
    *
@@ -141,68 +151,74 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
 
   }
 
+  /**
+   * Categorizes the document, failing loudly rather than returning an invalid 
distribution:
+   * malformed input is rejected with {@link IllegalArgumentException}, and 
any failure executing
+   * the model is surfaced as an {@link IllegalStateException} (cause 
preserved).
+   *
+   * @param strings The document to categorize; {@code strings[0]} is 
classified.
+   * @return The per-category probabilities.
+   * @throws IllegalArgumentException If {@code strings} is {@code null} or 
empty.
+   * @throws IllegalStateException    If inference fails or the model returns 
an unexpected output.
+   */
   @Override
   public double[] categorize(String[] strings) {
 
-    try {
+    if (strings == null || strings.length == 0) {
+      throw new IllegalArgumentException("strings must contain at least one 
document to categorize");
+    }
+
+    final List<Tokens> tokens = tokenize(strings[0]);
+
+    final List<double[]> scores = new LinkedList<>();
+    for (final Tokens t : tokens) {
+      scores.add(softmax(infer(t)));
+    }
+
+    return classificationScoringStrategy.score(scores);
+  }
+
+  /**
+   * Runs the model on one token window and returns its raw per-category 
logits. A failure executing
+   * the model (an {@link OrtException} or any runtime fault) is wrapped as an
+   * {@link IllegalStateException}; an unexpected output shape is its own loud 
failure.
+   */
+  private float[] infer(final Tokens t) {
 
-      final List<Tokens> tokens = tokenize(strings[0]);
-
-      final List<double[]> scores = new LinkedList<>();
-
-      for (final Tokens t : tokens) {
-
-        final Map<String, OnnxTensor> inputs = new HashMap<>();
-
-        final Object output;
-        try {
-          inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
-              LongBuffer.wrap(t.ids()), new long[] {1, t.ids().length}));
-
-          if (inferenceOptions.isIncludeAttentionMask()) {
-            inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
-                LongBuffer.wrap(t.mask()), new long[] {1, t.mask().length}));
-          }
-
-          if (inferenceOptions.isIncludeTokenTypeIds()) {
-            inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
-                LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
-          }
-
-          // The outputs from the model. Some models return a 2D array (e.g. 
BERT),
-          // while others return a 1D array (e.g. RoBERTa).
-          try (OrtSession.Result result = session.run(inputs)) {
-            // getValue() copies the tensor into Java arrays, so the result 
can be closed safely.
-            output = result.get(0).getValue();
-          }
-        } finally {
-          inputs.values().forEach(OnnxTensor::close);
-        }
-
-        final float[] rawScores;
-        if (output instanceof float[][] v) {
-          rawScores = v[0];
-        } else if (output instanceof float[] v) {
-          rawScores = v;
-        } else {
-          throw new IllegalStateException(
-              "Unexpected model output type: " + output.getClass().getName());
-        }
-
-        // Keep track of all scores.
-        final double[] categoryScoresForTokens = softmax(rawScores);
-        scores.add(categoryScoresForTokens);
+    final Map<String, OnnxTensor> inputs = new HashMap<>();
+    final Object output;
+    try {
+      inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
+          LongBuffer.wrap(t.ids()), new long[] {1, t.ids().length}));
 
+      if (inferenceOptions.isIncludeAttentionMask()) {
+        inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+            LongBuffer.wrap(t.mask()), new long[] {1, t.mask().length}));
       }
 
-      return classificationScoringStrategy.score(scores);
+      if (inferenceOptions.isIncludeTokenTypeIds()) {
+        inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+            LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
+      }
 
-    } catch (Exception ex) {
-      logger.error("Unload to perform document classification inference", ex);
+      // getValue() copies the tensor into Java arrays, so the result can be 
closed safely.
+      try (OrtSession.Result result = session.run(inputs)) {
+        output = result.get(0).getValue();
+      }
+    } catch (OrtException | RuntimeException ex) {
+      throw new IllegalStateException("Unable to perform document 
classification inference", ex);
+    } finally {
+      inputs.values().forEach(OnnxTensor::close);
     }
 
-    return new double[] {};
-
+    // Some models return a 2D array (e.g. BERT), others a 1D array (e.g. 
RoBERTa). A different
+    // shape is a model-contract violation, surfaced on its own rather than as 
"inference failed".
+    if (output instanceof float[][] v) {
+      return v[0];
+    } else if (output instanceof float[] v) {
+      return v;
+    }
+    throw new IllegalStateException("Unexpected model output type: " + 
output.getClass().getName());
   }
 
   @Override
@@ -298,23 +314,13 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
     // Split the input text into 200 word chunks with 50 overlapping between 
chunks.
     final String[] whitespaceTokenized = text.split("\\s+");
 
-    for (int start = 0; start < whitespaceTokenized.length;
-         start = start + inferenceOptions.getDocumentSplitSize()) {
-
-      // 200 word length chunk
-      // Check the end do don't go past and get a 
StringIndexOutOfBoundsException
-      int end = start + inferenceOptions.getDocumentSplitSize();
-      if (end > whitespaceTokenized.length) {
-        end = whitespaceTokenized.length;
-      }
+    for (final int[] range : chunkRanges(whitespaceTokenized.length,
+        inferenceOptions.getDocumentSplitSize(), 
inferenceOptions.getSplitOverlapSize())) {
 
-      // The group is that subsection of string.
-      final String group = String.join(" ", 
Arrays.copyOfRange(whitespaceTokenized, start, end));
+      // The group is that subsection of the input.
+      final String group =
+          String.join(" ", Arrays.copyOfRange(whitespaceTokenized, range[0], 
range[1]));
 
-      // We want to overlap each chunk by 50 words so scoot back 50 words for 
the next iteration.
-      start = start - inferenceOptions.getSplitOverlapSize();
-
-      // Now we can tokenize the group and continue.
       final String[] tokens = tokenizer.tokenize(group);
 
       final long[] ids = tokenIds(tokens, vocab);
@@ -333,6 +339,32 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
 
   }
 
+  /**
+   * Computes the {@code [start, end)} word-index ranges the input is split 
into: chunks of
+   * {@code splitSize} words overlapping by {@code overlapSize}. The loop 
always advances by
+   * at least one word, so a misconfigured {@code overlapSize >= splitSize} 
can neither stall
+   * the loop nor produce negative indices.
+   *
+   * @param length The number of whitespace-separated words.
+   * @param splitSize The chunk size in words.
+   * @param overlapSize The overlap between consecutive chunks in words.
+   * @return The ordered list of {@code [start, end)} ranges; empty when 
{@code length == 0}.
+   */
+  static List<int[]> chunkRanges(final int length, final int splitSize, final 
int overlapSize) {
+    final List<int[]> ranges = new ArrayList<>();
+    int start = 0;
+    while (start < length) {
+      final int end = Math.min(start + splitSize, length);
+      ranges.add(new int[] {start, end});
+      if (end == length) {
+        break;
+      }
+      // Overlap by overlapSize words, but always move forward by at least one.
+      start = Math.max(end - overlapSize, start + 1);
+    }
+    return ranges;
+  }
+
   /**
    * Maps tokens to their vocabulary ids.
    *
@@ -366,21 +398,27 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
    * @param input An array of values.
    * @return The output array.
    */
-  private double[] softmax(final float[] input) {
+  static double[] softmax(final float[] input) {
+
+    // Subtract the maximum before exponentiating (numerically stable 
softmax): exp() of a
+    // large logit otherwise overflows to +Infinity, yielding NaN scores. 
Mathematically
+    // identical to the naive form. Results are kept in double precision 
throughout.
+    double max = Double.NEGATIVE_INFINITY;
+    for (final float value : input) {
+      max = Math.max(max, value);
+    }
 
     final double[] t = new double[input.length];
     double sum = 0.0;
-
     for (int x = 0; x < input.length; x++) {
-      double val = Math.exp(input[x]);
+      final double val = Math.exp(input[x] - max);
       sum += val;
       t[x] = val;
     }
 
     final double[] output = new double[input.length];
-
     for (int x = 0; x < output.length; x++) {
-      output[x] = (float) (t[x] / sum);
+      output[x] = t[x] / sum;
     }
 
     return output;
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
index a6bab39f6..e962e4bee 100644
--- 
a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
@@ -18,13 +18,18 @@
 package opennlp.dl.doccat;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import org.junit.jupiter.api.Test;
 
+import opennlp.dl.InferenceOptions;
+import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;
 import opennlp.tools.tokenize.WordpieceTokenizer;
 
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
@@ -40,6 +45,47 @@ public class DocumentCategorizerDLTest {
     return vocab;
   }
 
+  private static Map<Integer, String> categories() {
+    final Map<Integer, String> categories = new HashMap<>();
+    categories.put(0, "negative");
+    categories.put(1, "positive");
+    return categories;
+  }
+
+  private static DocumentCategorizerDL categorizerWithoutSession() {
+    return new DocumentCategorizerDL(null, null, vocab(), categories(),
+        new AverageClassificationScoringStrategy(), new InferenceOptions());
+  }
+
+  @Test
+  void testCategorizeFailsLoudlyWhenInferenceFails() {
+    final IllegalStateException e = assertThrows(IllegalStateException.class, 
() ->
+        categorizerWithoutSession().categorize(new String[] {"hello world"}));
+
+    assertTrue(e.getMessage().contains("document classification inference"));
+    assertTrue(e.getCause() instanceof RuntimeException);
+  }
+
+  @Test
+  void testScoreMapsFailLoudlyWhenInferenceFails() {
+    final DocumentCategorizerDL categorizer = categorizerWithoutSession();
+
+    assertThrows(IllegalStateException.class, () ->
+        categorizer.scoreMap(new String[] {"hello world"}));
+    assertThrows(IllegalStateException.class, () ->
+        categorizer.sortedScoreMap(new String[] {"hello world"}));
+  }
+
+  @Test
+  void testCategorizeRejectsMalformedInput() {
+    // A caller-side input bug is distinguished from an inference failure: it 
is rejected up front
+    // with IllegalArgumentException, not wrapped as "document classification 
inference" failure.
+    final DocumentCategorizerDL categorizer = categorizerWithoutSession();
+
+    assertThrows(IllegalArgumentException.class, () -> 
categorizer.categorize(null));
+    assertThrows(IllegalArgumentException.class, () -> 
categorizer.categorize(new String[0]));
+  }
+
   @Test
   void testTokenIdsMapsTokensToVocabularyIds() {
     final long[] ids = DocumentCategorizerDL.tokenIds(
@@ -57,4 +103,83 @@ public class DocumentCategorizerDLTest {
     assertTrue(e.getMessage().contains("missing"),
         "the error message should name the missing token: " + e.getMessage());
   }
+
+  @Test
+  void testSoftmaxIsUniformForEqualLogitsAndSumsToOne() {
+    final double[] out = DocumentCategorizerDL.softmax(new float[] {0f, 0f, 
0f});
+
+    assertEquals(3, out.length);
+    for (final double p : out) {
+      assertEquals(1.0 / 3.0, p, 1e-12);
+    }
+    assertEquals(1.0, out[0] + out[1] + out[2], 1e-12);
+  }
+
+  @Test
+  void testSoftmaxIsNumericallyStableForLargeLogits() {
+    // The naive exp(logit) form overflows to +Infinity here and yields NaN; 
subtracting
+    // the maximum keeps every value finite and the distribution uniform.
+    final double[] out = DocumentCategorizerDL.softmax(new float[] {1000f, 
1000f, 1000f});
+
+    double sum = 0.0;
+    for (final double p : out) {
+      assertFalse(Double.isNaN(p) || Double.isInfinite(p),
+          "softmax must stay finite for large logits");
+      assertEquals(1.0 / 3.0, p, 1e-9);
+      sum += p;
+    }
+    assertEquals(1.0, sum, 1e-12);
+  }
+
+  @Test
+  void testSoftmaxMatchesReferenceDistribution() {
+    // Reference (numpy): softmax([1,2,3]) = [0.09003057, 0.24472847, 
0.66524096].
+    final double[] out = DocumentCategorizerDL.softmax(new float[] {1f, 2f, 
3f});
+
+    assertEquals(0.09003057, out[0], 1e-6);
+    assertEquals(0.24472847, out[1], 1e-6);
+    assertEquals(0.66524096, out[2], 1e-6);
+  }
+
+  @Test
+  void testChunkRangesSplitsWithOverlap() {
+    // 210 words, 200-word chunks overlapping by 50 -> [0,200), [150,210).
+    final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(210, 200, 50);
+
+    assertEquals(2, ranges.size());
+    assertArrayEquals(new int[] {0, 200}, ranges.get(0));
+    assertArrayEquals(new int[] {150, 210}, ranges.get(1));
+  }
+
+  @Test
+  void testChunkRangesSingleChunkWhenShorterThanSplit() {
+    final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(30, 200, 50);
+
+    assertEquals(1, ranges.size());
+    assertArrayEquals(new int[] {0, 30}, ranges.get(0));
+  }
+
+  @Test
+  void testChunkRangesEmptyForZeroLength() {
+    assertTrue(DocumentCategorizerDL.chunkRanges(0, 200, 50).isEmpty());
+  }
+
+  @Test
+  void testChunkRangesAlwaysProgressesForInvalidOverlap() {
+    // overlap == split would stall forever, and overlap > split would make 
the start index
+    // negative, without the forward-progress guard.
+    for (final int[] cfg : new int[][] {{10, 5, 5}, {8, 3, 10}, {7, 4, 100}}) {
+      final int length = cfg[0];
+      final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(length, 
cfg[1], cfg[2]);
+
+      int previousStart = -1;
+      for (final int[] range : ranges) {
+        assertTrue(range[0] >= 0, "start must never be negative: " + range[0]);
+        assertTrue(range[1] >= range[0], "end must be >= start");
+        assertTrue(range[0] > previousStart, "each chunk must advance the 
start index");
+        previousStart = range[0];
+      }
+      assertEquals(length, ranges.get(ranges.size() - 1)[1], "last chunk must 
reach the end");
+    }
+  }
 }
diff --git 
a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
 
b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index e045ec90e..d19443b5b 100644
--- 
a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ 
b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -34,6 +34,7 @@ import org.slf4j.LoggerFactory;
 import opennlp.dl.InferenceOptions;
 import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;
 import opennlp.tools.eval.AbstractEvalTest;
+import opennlp.tools.tokenize.WordpieceTokenizer;
 
 public class DocumentCategorizerDLEval extends AbstractEvalTest {
 
@@ -91,6 +92,27 @@ public class DocumentCategorizerDLEval extends 
AbstractEvalTest {
 
   }
 
+  @Test
+  public void categorizeFailsLoudlyOnFailure() throws Exception {
+
+    try (final DocumentCategorizerDL documentCategorizerDL =
+             categorizerWithoutSession()) {
+
+      // Empty input drives categorize() down its failure path (strings[0] 
throws) before any
+      // inference; it must fail loudly rather than return an invalid all-zero 
distribution.
+      final IllegalStateException e = 
Assertions.assertThrows(IllegalStateException.class, () ->
+          documentCategorizerDL.categorize(new String[0]));
+      Assertions.assertTrue(e.getMessage().contains("document classification 
inference"));
+
+      // The dependent API must not mask that inference failure with all-zero 
scores.
+      Assertions.assertThrows(IllegalStateException.class, () ->
+          documentCategorizerDL.scoreMap(new String[0]));
+      Assertions.assertThrows(IllegalStateException.class, () ->
+          documentCategorizerDL.sortedScoreMap(new String[0]));
+    }
+
+  }
+
   @Test
   public void categorizeWithAutomaticLabels() throws Exception {
 
@@ -309,4 +331,19 @@ public class DocumentCategorizerDLEval extends 
AbstractEvalTest {
 
   }
 
+  private DocumentCategorizerDL categorizerWithoutSession() {
+    return new DocumentCategorizerDL(null, null, vocab(), getCategories(),
+        new AverageClassificationScoringStrategy(), new InferenceOptions());
+  }
+
+  private Map<String, Integer> vocab() {
+    final Map<String, Integer> vocab = new HashMap<>();
+    vocab.put(WordpieceTokenizer.BERT_CLS_TOKEN, 0);
+    vocab.put(WordpieceTokenizer.BERT_SEP_TOKEN, 1);
+    vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2);
+    vocab.put("hello", 3);
+    vocab.put("world", 4);
+    return vocab;
+  }
+
 }

Reply via email to