This is an automated email from the ASF dual-hosted git repository.
jzemerick pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 9b62d587 OPENNLP-1442: Sentence transformers (#523)
9b62d587 is described below
commit 9b62d587106d6bf7b4ccfdb67e4fa4378c533941
Author: Jeff Zemerick <[email protected]>
AuthorDate: Fri Mar 31 15:19:00 2023 -0400
OPENNLP-1442: Sentence transformers (#523)
* OPENNLP-1442: Adding sentence transformer model support via ONNX Runtime.
---
opennlp-dl/README.md | 56 ++++++-----
.../src/main/java/opennlp/dl/AbstractDL.java | 72 +++++++++++++
.../opennlp/dl/doccat/DocumentCategorizerDL.java | 56 ++---------
.../java/opennlp/dl/namefinder/NameFinderDL.java | 46 +--------
.../java/opennlp/dl/vectors/SentenceVectorsDL.java | 112 +++++++++++++++++++++
.../src/test/java/opennlp/dl/AbstractDLTest.java | 2 +-
.../opennlp/dl/vectors/SentenceVectorsDLEval.java | 50 +++++++++
7 files changed, 278 insertions(+), 116 deletions(-)
diff --git a/opennlp-dl/README.md b/opennlp-dl/README.md
index 1c64a76c..21374e33 100644
--- a/opennlp-dl/README.md
+++ b/opennlp-dl/README.md
@@ -4,44 +4,50 @@ This module provides OpenNLP interface implementations for
ONNX models using the
**Important**: This does not provide the ability to train models. Model
training is done outside of OpenNLP. This code provides the ability to use ONNX
models from OpenNLP.
-To build with example models, download the models to the `/src/test/resources`
directory. (These are the exported models described below.)
+Models used in the tests are available in the OpenNLP evaluation test data.
+## NameFinderDL
+
+Export a Hugging Face NER model to ONNX, e.g.:
+
+```bash
+python -m transformers.onnx --model=dslim/bert-base-NER --feature token-classification exported
```
-export OPENNLP_DATA=/tmp/
-mkdir /tmp/dl-doccat /tmp/dl-namefinder
+## DocumentCategorizerDL
-# Document categorizer model
-wget https://www.dropbox.com/s/n9uzs8r4xm9rhxb/model.onnx?dl=0 -O
$OPENNLP_DATA/dl-doccat/model.onnx
-wget https://www.dropbox.com/s/aw6yjc68jw0jts6/vocab.txt?dl=0 -O
$OPENNLP_DATA/dl-doccat/vocab.txt
+Export a Hugging Face classification (e.g. sentiment) model to ONNX, e.g.:
-# Namefinder model
-wget https://www.dropbox.com/s/zgogq65gs9tyfm1/model.onnx?dl=0 -O
$OPENNLP_DATA/dl-namefinder/model.onnx
-wget https://www.dropbox.com/s/3byt1jggly1dg98/vocab.txt?dl=0 -O
$OPENNLP_DATA/dl-/namefinder/vocab.txt
+```bash
+python -m transformers.onnx --model=nlptown/bert-base-multilingual-uncased-sentiment --feature sequence-classification exported
```
-## TokenNameFinder
+## SentenceVectors
-* Export a Huggingface NER model to ONNX, e.g.:
+Convert a sentence vectors model to ONNX, e.g.:
-```
-python -m transformers.onnx --model=dslim/bert-base-NER --feature
token-classification exported
-```
+Install dependencies:
-* Copy the exported model to `src/test/resources/namefinder/model.onnx`.
-* Copy the model's
[vocab.txt](https://huggingface.co/dslim/bert-base-NER/tree/main) to
`src/test/resources/namefinder/vocab.txt`.
+```bash
+python3 -m pip install optimum onnx onnxruntime
+```
-Now you can run the tests in `NameFinderDLTest`.
+Convert the model:
-## DocumentCategorizer
+```python
+from optimum.onnxruntime import ORTModelForFeatureExtraction
+from transformers import AutoTokenizer
+from pathlib import Path
-* Export a Huggingface classification (e.g. sentiment) model to ONNX, e.g.:
-```
-python -m transformers.onnx
--model=nlptown/bert-base-multilingual-uncased-sentiment --feature
sequence-classification exported
-```
+model_id="sentence-transformers/all-MiniLM-L6-v2"
+onnx_path = Path("onnx")
-* Copy the exported model to `src/test/resources/doccat/model.onnx`.
-* Copy the model's
[vocab.txt](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment/tree/main)
to `src/test/resources/namefinder/vocab.txt`.
+# load vanilla transformers and convert to onnx
+model = ORTModelForFeatureExtraction.from_pretrained(model_id,
from_transformers=True)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
-Now you can run the tests in `DocumentCategorizerDLTest`.
\ No newline at end of file
+# save onnx checkpoint and tokenizer
+model.save_pretrained(onnx_path)
+tokenizer.save_pretrained(onnx_path)
+```
\ No newline at end of file
diff --git a/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
b/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
new file mode 100644
index 00000000..2d22a359
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Stream;
+
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtSession;
+
+import opennlp.tools.tokenize.Tokenizer;
+
+/**
+ * Base class for OpenNLP deep-learning classes using ONNX Runtime.
+ */
+public abstract class AbstractDL {
+
+ public static final String INPUT_IDS = "input_ids";
+ public static final String ATTENTION_MASK = "attention_mask";
+ public static final String TOKEN_TYPE_IDS = "token_type_ids";
+
+ protected OrtEnvironment env;
+ protected OrtSession session;
+ protected Tokenizer tokenizer;
+ protected Map<String, Integer> vocab;
+
+ /**
+ * Loads a vocabulary file from disk.
+ * @param vocabFile The vocabulary file.
+ * @return A map of vocabulary words to integer IDs.
+ * @throws IOException Thrown if the vocabulary file cannot be opened and read.
+ */
+ public Map<String, Integer> loadVocab(final File vocabFile) throws
IOException {
+
+ final Map<String, Integer> vocab = new HashMap<>();
+
+ final AtomicInteger counter = new AtomicInteger(0);
+
+ try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {
+
+ lines.forEach(line -> {
+ vocab.put(line, counter.getAndIncrement());
+ });
+
+ }
+
+ return vocab;
+
+ }
+
+}
diff --git
a/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
b/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index aff9a929..6c811a14 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -17,9 +17,7 @@
package opennlp.dl.doccat;
-import java.io.BufferedReader;
import java.io.File;
-import java.io.FileReader;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
@@ -40,43 +38,36 @@ import ai.onnxruntime.OrtSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
-import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
/**
* An implementation of {@link DocumentCategorizer} that performs document
classification
* using ONNX models.
*/
-public class DocumentCategorizerDL implements DocumentCategorizer {
+public class DocumentCategorizerDL extends AbstractDL implements
DocumentCategorizer {
private static final Logger logger =
LoggerFactory.getLogger(DocumentCategorizerDL.class);
- public static final String INPUT_IDS = "input_ids";
- public static final String ATTENTION_MASK = "attention_mask";
- public static final String TOKEN_TYPE_IDS = "token_type_ids";
- private final Tokenizer tokenizer;
- private final Map<String, Integer> vocabulary;
private final Map<Integer, String> categories;
private final ClassificationScoringStrategy classificationScoringStrategy;
private final InferenceOptions inferenceOptions;
- protected final OrtEnvironment env;
- protected final OrtSession session;
/**
* Creates a new document categorizer using ONNX models.
- * @param model The ONNX model file.
- * @param vocab The model's vocabulary file.
+ * @param modelFile The ONNX model file.
+ * @param vocabFile The model's vocabulary file.
* @param categories The categories.
* @param classificationScoringStrategy Implementation of {@link
ClassificationScoringStrategy} used
* to calculate the classification
scores given the score of each
* individual document part.
* @param inferenceOptions {@link InferenceOptions} to control the inference.
*/
- public DocumentCategorizerDL(File model, File vocab, Map<Integer, String>
categories,
+ public DocumentCategorizerDL(File modelFile, File vocabFile, Map<Integer,
String> categories,
ClassificationScoringStrategy
classificationScoringStrategy,
InferenceOptions inferenceOptions)
throws IOException, OrtException {
@@ -88,9 +79,9 @@ public class DocumentCategorizerDL implements
DocumentCategorizer {
sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
}
- this.session = env.createSession(model.getPath(), sessionOptions);
- this.vocabulary = loadVocab(vocab);
- this.tokenizer = new WordpieceTokenizer(vocabulary.keySet());
+ this.session = env.createSession(modelFile.getPath(), sessionOptions);
+ this.vocab = loadVocab(vocabFile);
+ this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.categories = categories;
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
@@ -223,33 +214,6 @@ public class DocumentCategorizerDL implements
DocumentCategorizer {
}
- /**
- * Loads a vocabulary file from disk.
- * @param vocab The vocabulary file.
- * @return A map of vocabulary words to integer IDs.
- * @throws IOException Thrown if the vocabulary file cannot be opened and
read.
- */
- private Map<String, Integer> loadVocab(File vocab) throws IOException {
-
- final Map<String, Integer> v = new HashMap<>();
-
- BufferedReader br = new BufferedReader(new FileReader(vocab.getPath()));
- String line = br.readLine();
- int x = 0;
-
- while (line != null) {
-
- line = br.readLine();
- x++;
-
- v.put(line, x);
-
- }
-
- return v;
-
- }
-
private Tokens oldTokenize(String text) {
final String[] tokens = tokenizer.tokenize(text);
@@ -257,7 +221,7 @@ public class DocumentCategorizerDL implements
DocumentCategorizer {
final int[] ids = new int[tokens.length];
for (int x = 0; x < tokens.length; x++) {
- ids[x] = vocabulary.get(tokens[x]);
+ ids[x] = vocab.get(tokens[x]);
}
final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
@@ -306,7 +270,7 @@ public class DocumentCategorizerDL implements
DocumentCategorizer {
final int[] ids = new int[tokens.length];
for (int x = 0; x < tokens.length; x++) {
- ids[x] = vocabulary.get(tokens[x]);
+ ids[x] = vocab.get(tokens[x]);
}
final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
diff --git a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index 1049353e..ba17e5fd 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -17,10 +17,7 @@
package opennlp.dl.namefinder;
-import java.io.BufferedReader;
import java.io.File;
-import java.io.FileReader;
-import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
@@ -35,23 +32,19 @@ import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
-import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;
/**
* An implementation of {@link TokenNameFinder} that uses ONNX models.
*/
-public class NameFinderDL implements TokenNameFinder {
-
- public static final String INPUT_IDS = "input_ids";
- public static final String ATTENTION_MASK = "attention_mask";
- public static final String TOKEN_TYPE_IDS = "token_type_ids";
+public class NameFinderDL extends AbstractDL implements TokenNameFinder {
public static final String I_PER = "I-PER";
public static final String B_PER = "B-PER";
@@ -59,14 +52,9 @@ public class NameFinderDL implements TokenNameFinder {
private static final String CHARS_TO_REPLACE = "##";
- protected final OrtSession session;
-
private final SentenceDetector sentenceDetector;
private final Map<Integer, String> ids2Labels;
- private final Tokenizer tokenizer;
- private final Map<String, Integer> vocab;
private final InferenceOptions inferenceOptions;
- protected final OrtEnvironment env;
public NameFinderDL(File model, File vocabulary, Map<Integer, String>
ids2Labels,
SentenceDetector sentenceDetector) throws Exception {
@@ -384,34 +372,4 @@ public class NameFinderDL implements TokenNameFinder {
}
- /**
- * Loads a vocabulary file from disk.
- * @param vocab The vocabulary file.
- * @return A map of vocabulary words to integer IDs.
- * @throws IOException Thrown if the vocabulary file cannot be opened and
read.
- */
- private Map<String, Integer> loadVocab(File vocab) throws IOException {
-
- final Map<String, Integer> v = new HashMap<>();
-
- try (final BufferedReader br = new BufferedReader(new
FileReader(vocab.getPath()))) {
-
- String line = br.readLine();
- int x = 0;
-
- while (line != null) {
-
- line = br.readLine();
- x++;
-
- v.put(line, x);
-
- }
-
- }
-
- return v;
-
- }
-
}
diff --git a/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
b/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
new file mode 100644
index 00000000..959e6d89
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.vectors;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.LongBuffer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import opennlp.dl.AbstractDL;
+import opennlp.dl.Tokens;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WordpieceTokenizer;
+
+/**
+ * Facilitates the generation of sentence vectors using
+ * a sentence-transformers model converted to ONNX.
+ */
+public class SentenceVectorsDL extends AbstractDL {
+
+ private static final Logger logger =
LoggerFactory.getLogger(SentenceVectorsDL.class);
+
+ /**
+ * Creates an instance of the class.
+ * @param model The sentence vectors ONNX model file.
+ * @param vocabulary The vocabulary file for the model.
+ * @throws OrtException Thrown if the model cannot be loaded.
+ * @throws IOException Thrown if the vocabulary file cannot be loaded.
+ */
+ public SentenceVectorsDL(final File model, final File vocabulary)
+ throws OrtException, IOException {
+
+ env = OrtEnvironment.getEnvironment();
+ session = env.createSession(model.getPath(), new
OrtSession.SessionOptions());
+ vocab = loadVocab(new File(vocabulary.getPath()));
+ tokenizer = new WordpieceTokenizer(vocab.keySet());
+
+ }
+
+ /**
+ * Generates vectors given a sentence.
+ * @param sentence The input sentence.
+ * @throws OrtException Thrown if an error occurs during inference.
+ */
+ public float[] getVectors(final String sentence) throws OrtException {
+
+ final Tokens tokens = tokenize(sentence, tokenizer, vocab);
+
+ final Map<String, OnnxTensor> inputs = new HashMap<>();
+
+ inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.getIds()),
+ new long[] {1, tokens.getIds().length}));
+
+ inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(tokens.getMask()), new long[] {1,
tokens.getMask().length}));
+
+ inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(tokens.getTypes()), new long[] {1,
tokens.getTypes().length}));
+
+ final float[][][] v = (float[][][]) session.run(inputs).get(0).getValue();
+
+ final float[] vectors = v[0][0];
+
+ return vectors;
+
+ }
+
+ private Tokens tokenize(final String text, Tokenizer tokenizer, Map<String,
Integer> vocab) {
+
+ final String[] tokens = tokenizer.tokenize(text);
+
+ final int[] ids = new int[tokens.length];
+ final long[] mask = new long[ids.length];
+
+ for (int x = 0; x < tokens.length; x++) {
+ ids[x] = vocab.get(tokens[x]);
+ }
+
+ final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
+
+ final long[] types = new long[ids.length];
+ Arrays.fill(types, 1);
+
+ return new Tokens(tokens, lids, mask, types);
+
+ }
+
+}
diff --git a/opennlp-dl/src/test/java/opennlp/dl/AbstractDLTest.java
b/opennlp-dl/src/test/java/opennlp/dl/AbstractDLTest.java
index 68d139e5..b5f694e5 100644
--- a/opennlp-dl/src/test/java/opennlp/dl/AbstractDLTest.java
+++ b/opennlp-dl/src/test/java/opennlp/dl/AbstractDLTest.java
@@ -24,7 +24,7 @@ import opennlp.tools.util.StringUtil;
public abstract class AbstractDLTest {
- public static File getOpennlpDataDir() throws FileNotFoundException {
+ public File getOpennlpDataDir() throws FileNotFoundException {
final String dataDirectory = System.getProperty("OPENNLP_DATA_DIR");
if (dataDirectory == null || StringUtil.isEmpty(dataDirectory)) {
throw new IllegalArgumentException("The OPENNLP_DATA_DIR is not set.");
diff --git
a/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
b/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
new file mode 100644
index 00000000..f63fa3f4
--- /dev/null
+++ b/opennlp-dl/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl.vectors;
+
+import java.io.File;
+import java.io.IOException;
+
+import ai.onnxruntime.OrtException;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import opennlp.dl.AbstractDLTest;
+
+public class SentenceVectorsDLEval extends AbstractDLTest {
+
+ @Test
+ public void generateVectorsTest() throws IOException, OrtException {
+
+ final File MODEL_FILE_NAME = new File(getOpennlpDataDir(),
"onnx/sentence-transformers/model.onnx");
+ final File VOCAB_FILE_NAME = new File(getOpennlpDataDir(),
"onnx/sentence-transformers/vocab.txt");
+
+ final String sentence = "george washington was president";
+
+ final SentenceVectorsDL sv = new SentenceVectorsDL(MODEL_FILE_NAME,
VOCAB_FILE_NAME);
+
+ final float[] vectors = sv.getVectors(sentence);
+
+ Assertions.assertEquals(vectors[0], 0.39994872, 0.00001);
+ Assertions.assertEquals(vectors[1], -0.055101186, 0.00001);
+ Assertions.assertEquals(vectors[2], 0.2817594, 0.00001);
+ Assertions.assertEquals(vectors.length, 384);
+
+ }
+
+}