This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-bert-qa-example in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 1106bacaa96514af92abac57d171b4cb3934ddb8 Author: gigasquid <[email protected]> AuthorDate: Fri Apr 12 19:27:58 2019 -0400 cleaning up example --- .../clojure-package/examples/bert-qa/project.clj | 2 +- .../examples/bert-qa/src/bert_qa/core.clj | 15 +- .../examples/bert-qa/src/java/BertDataParser.java | 126 ----------------- .../examples/bert-qa/src/java/BertQA.java | 152 --------------------- 4 files changed, 9 insertions(+), 286 deletions(-) diff --git a/contrib/clojure-package/examples/bert-qa/project.clj b/contrib/clojure-package/examples/bert-qa/project.clj index 328d040..5bec165 100644 --- a/contrib/clojure-package/examples/bert-qa/project.clj +++ b/contrib/clojure-package/examples/bert-qa/project.clj @@ -3,8 +3,8 @@ :plugins [[lein-cljfmt "0.5.7"]] :dependencies [[org.clojure/clojure "1.9.0"] [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"] - [com.google.code.gson/gson "2.8.5"] [cheshire "5.8.1"]] :pedantic? :skip :java-source-paths ["src/java"] + :main bert-qa.core :repl-options {:init-ns bert-qa.core}) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj index 1c7b14c..02f2d34 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj @@ -8,9 +8,7 @@ [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.ndarray :as ndarray] - [org.apache.clojure-mxnet.infer :as infer]) - (:import (bert BertDataParser) - (bert BertQA))) + [org.apache.clojure-mxnet.infer :as infer])) (def model-path-prefix "model/static_bert_qa") ;; epoch number of the model @@ -58,7 +56,6 @@ (defn idxs->tokens [idx2token idxs] (mapv #(get idx2token %) idxs)) - (defn post-processing [result tokens] (let [output1 (ndarray/slice-axis result 2 0 1) output2 (ndarray/slice-axis result 2 1 2) @@ -75,9 +72,8 @@ (first))] (subvec tokens (dec start-idx) (inc end-idx)))) -(defn infer [] - (let [ctx (context/default-context) - ;;; pre-processing tokenize sentence +(defn infer [ctx] + (let [;;; pre-processing tokenize sentence token-q (tokenizer (string/lower-case input-q)) token-a (tokenizer (string/lower-case input-a)) valid-length (+ (count token-q) (count token-a)) @@ -122,6 +118,11 @@ (println "Answer paragraph: " input-a) (println "Answer: " answer))) +(defn -main [& args] + (let [[dev] args] + (if (= dev ":gpu") + (infer (context/gpu)) + (infer (context/cpu))))) (comment diff --git a/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java b/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java deleted file mode 100644 index a0a821a..0000000 --- a/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnetexamples.javaapi.infer.bert; - -import java.io.FileReader; -import java.util.*; - -import com.google.gson.Gson; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; - -/** - * This is the Utility for pre-processing the data for Bert Model - * You can use this utility to parse Vocabulary JSON into Java Array and Dictionary, - * clean and tokenize sentences and pad the text - */ -public class BertDataParser { - - private Map<String, Integer> token2idx; - private List<String> idx2token; - - /** - * Parse the Vocabulary to JSON files - * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens - * @param jsonFile the filePath of the vocab.json - * @throws Exception - */ - public void parseJSON(String jsonFile) throws Exception { - Gson gson = new Gson(); - token2idx = new HashMap<>(); - idx2token = new LinkedList<>(); - JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class); - JsonArray arr = jsonObject.getAsJsonArray("idx_to_token"); - for (JsonElement element : arr) { - idx2token.add(element.getAsString()); - } - JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx"); - for (String key : preMap.keySet()) { - token2idx.put(key, preMap.get(key).getAsInt()); - } - } - - /** - * Tokenize the input, split all kinds of whitespace and - * Separate the end of sentence symbol: . , ? ! - * @param input The input string - * @return List of tokens - */ - public List<String> tokenizer(String input) { - String[] step1 = input.split("\\s+"); - List<String> finalResult = new LinkedList<>(); - for (String item : step1) { - if (item.length() != 0) { - if ((item + "a").split("[.,?!]+").length > 1) { - finalResult.add(item.substring(0, item.length() - 1)); - finalResult.add(item.substring(item.length() -1)); - } else { - finalResult.add(item); - } - } - } - return finalResult; - } - - /** - * Pad the tokens to the required length - * @param tokens input tokens - * @param padItem things to pad at the end - * @param num total length after padding - * @return List of padded tokens - */ - public <E> List<E> pad(List<E> tokens, E padItem, int num) { - if (tokens.size() >= num) return tokens; - List<E> padded = new LinkedList<>(tokens); - for (int i = 0; i < num - tokens.size(); i++) { - padded.add(padItem); - } - return padded; - } - - /** - * Convert tokens to indexes - * @param tokens input tokens - * @return List of indexes - */ - public List<Integer> token2idx(List<String> tokens) { - List<Integer> indexes = new ArrayList<>(); - for (String token : tokens) { - if (token2idx.containsKey(token)) { - indexes.add(token2idx.get(token)); - } else { - indexes.add(token2idx.get("[UNK]")); - } - } - return indexes; - } - - /** - * Convert indexes to tokens - * @param indexes List of indexes - * @return List of tokens - */ - public List<String> idx2token(List<Integer> indexes) { - List<String> tokens = new ArrayList<>(); - for (int index : indexes) { - tokens.add(idx2token.get(index)); - } - return tokens; - } -} diff --git a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java deleted file mode 100644 index 8521f0b..0000000 --- a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package bert; - -import org.apache.mxnet.infer.javaapi.Predictor; -import org.apache.mxnet.javaapi.*; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.Option; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.*; - -/** - * This is an example of using BERT to do the general Question and Answer inference jobs - * Users can provide a question with a paragraph contains answer to the model and - * the model will be able to find the best answer from the answer paragraph - */ -public class BertQA { - @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") - private String modelPathPrefix = "model/static_bert_qa"; - @Option(name = "--model-epoch", usage = "Epoch number of the model") - private int epoch = 2; - @Option(name = "--model-vocab", usage = "the vocabulary used in the model") - private String modelVocab = "model/vocab.json"; - @Option(name = "--input-question", usage = "the input question") - private String inputQ = "When did BBC Japan start broadcasting?"; - @Option(name = "--input-answer", usage = "the input answer") - private String inputA = - "BBC Japan was a general entertainment Channel.\n" + - " Which operated between December 2004 and April 2006.\n" + - "It ceased operations after its Japanese distributor folded."; - @Option(name = "--seq-length", usage = "the maximum length of the sequence") - private int seqLength = 384; - - private final static Logger logger = LoggerFactory.getLogger(BertQA.class); - private static NDArray$ NDArray = NDArray$.MODULE$; - - private static int argmax(float[] prob) { - int maxIdx = 0; - for (int i = 0; i < prob.length; i++) { - if (prob[maxIdx] < prob[i]) maxIdx = i; - } - return maxIdx; - } - - /** - * Do the post processing on the output, apply softmax to get the probabilities - * reshape and get the most probable index - * @param result prediction result - * @param tokens word tokens - * @return Answers clipped from the original paragraph - */ - static List<String> postProcessing(NDArray result, List<String> tokens) { - NDArray[] output = NDArray.split( - NDArray.new splitParam(result, 2).setAxis(2)); - logger.info("Carin postprocessing output: " + Arrays.toString(output)); - // Get the formatted logits result - NDArray startLogits = output[0].reshape(new int[]{0, -3}); - NDArray endLogits = output[1].reshape(new int[]{0, -3}); - // Get Probability distribution - float[] startProb = NDArray.softmax( - NDArray.new softmaxParam(startLogits))[0].toArray(); - float[] endProb = NDArray.softmax( - NDArray.new softmaxParam(endLogits))[0].toArray(); - int startIdx = argmax(startProb); - int endIdx = argmax(endProb); - logger.info("Carin startIdx "+ startIdx); - logger.info("Carin endIdx "+ startIdx); - return tokens.subList(startIdx, endIdx + 1); - } - - public static void main(String[] args) throws Exception{ - BertQA inst = new BertQA(); - CmdLineParser parser = new CmdLineParser(inst); - parser.parseArgument(args); - BertDataParser util = new BertDataParser(); - Context context = Context.cpu(); - if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && - Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { - context = Context.gpu(); - } - // pre-processing - tokenize sentence - List<String> tokenQ = util.tokenizer(inst.inputQ.toLowerCase()); - List<String> tokenA = util.tokenizer(inst.inputA.toLowerCase()); - int validLength = tokenQ.size() + tokenA.size(); - logger.info("Valid length: " + validLength); - // generate token types [0000...1111....0000] - List<Float> QAEmbedded = new ArrayList<>(); - util.pad(QAEmbedded, 0f, tokenQ.size()).addAll( - util.pad(new ArrayList<Float>(), 1f, tokenA.size()) - ); - List<Float> tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength); - // make BERT pre-processing standard - tokenQ.add("[SEP]"); - tokenQ.add(0, "[CLS]"); - tokenA.add("[SEP]"); - tokenQ.addAll(tokenA); - List<String> tokens = util.pad(tokenQ, "[PAD]", inst.seqLength); - logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray())); - // pre-processing - token to index translation - util.parseJSON(inst.modelVocab); - List<Integer> indexes = util.token2idx(tokens); - List<Float> indexesFloat = new ArrayList<>(); - for (int integer : indexes) { - indexesFloat.add((float) integer); - } - // Preparing the input data - List<NDArray> inputBatch = Arrays.asList( - new NDArray(indexesFloat, - new Shape(new int[]{1, inst.seqLength}), context), - new NDArray(tokenTypes, - new Shape(new int[]{1, inst.seqLength}), context), - new NDArray(new float[] { validLength }, - new Shape(new int[]{1}), context) - ); - logger.info("Carin inputbatch: " + Arrays.toString(inputBatch.toArray())); - // Build the model - List<Context> contexts = new ArrayList<>(); - contexts.add(context); - List<DataDesc> inputDescs = Arrays.asList( - new DataDesc("data0", - new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), - new DataDesc("data1", - new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), - new DataDesc("data2", - new Shape(new int[]{1}), DType.Float32(), Layout.N()) - ); - Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); - // Start prediction - NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); - List<String> answer = postProcessing(result, tokens); - logger.info("Question: " + inst.inputQ); - logger.info("Answer paragraph: " + inst.inputA); - logger.info("Answer: " + Arrays.toString(answer.toArray())); - } -}
