This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-qa-example
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 1106bacaa96514af92abac57d171b4cb3934ddb8
Author: gigasquid <[email protected]>
AuthorDate: Fri Apr 12 19:27:58 2019 -0400

    cleaning up example
---
 .../clojure-package/examples/bert-qa/project.clj   |   2 +-
 .../examples/bert-qa/src/bert_qa/core.clj          |  15 +-
 .../examples/bert-qa/src/java/BertDataParser.java  | 126 -----------------
 .../examples/bert-qa/src/java/BertQA.java          | 152 ---------------------
 4 files changed, 9 insertions(+), 286 deletions(-)

diff --git a/contrib/clojure-package/examples/bert-qa/project.clj b/contrib/clojure-package/examples/bert-qa/project.clj
index 328d040..5bec165 100644
--- a/contrib/clojure-package/examples/bert-qa/project.clj
+++ b/contrib/clojure-package/examples/bert-qa/project.clj
@@ -3,8 +3,8 @@
   :plugins [[lein-cljfmt "0.5.7"]]
   :dependencies [[org.clojure/clojure "1.9.0"]
                  [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]
-                 [com.google.code.gson/gson "2.8.5"]
                  [cheshire "5.8.1"]]
   :pedantic? :skip
   :java-source-paths ["src/java"]
+  :main bert-qa.core
   :repl-options {:init-ns bert-qa.core})
diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
index 1c7b14c..02f2d34 100644
--- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
@@ -8,9 +8,7 @@
             [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
-            [org.apache.clojure-mxnet.infer :as infer])
-  (:import (bert BertDataParser)
-           (bert BertQA)))
+            [org.apache.clojure-mxnet.infer :as infer]))
 
 (def model-path-prefix "model/static_bert_qa")
 ;; epoch number of the model
@@ -58,7 +56,6 @@
 (defn idxs->tokens [idx2token idxs]
   (mapv #(get idx2token %) idxs))
 
-
 (defn post-processing [result tokens]
   (let [output1 (ndarray/slice-axis result 2 0 1)
         output2 (ndarray/slice-axis result 2 1 2)
@@ -75,9 +72,8 @@
                     (first))]
     (subvec tokens (dec start-idx) (inc end-idx))))
 
-(defn infer []
-  (let [ctx (context/default-context)
-        ;;; pre-processing tokenize sentence
+(defn infer [ctx]
+  (let [;;; pre-processing tokenize sentence
         token-q (tokenizer (string/lower-case input-q))
         token-a (tokenizer (string/lower-case input-a))
         valid-length (+ (count token-q) (count token-a))
@@ -122,6 +118,11 @@
     (println "Answer paragraph: " input-a)
     (println "Answer: " answer)))
 
+(defn -main [& args]
+  (let [[dev] args]
+            (if (= dev ":gpu")
+              (infer (context/gpu))
+              (infer (context/cpu)))))
 
 (comment
 
diff --git a/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java b/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java
deleted file mode 100644
index a0a821a..0000000
--- a/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mxnetexamples.javaapi.infer.bert;
-
-import java.io.FileReader;
-import java.util.*;
-
-import com.google.gson.Gson;
-import com.google.gson.JsonArray;
-import com.google.gson.JsonElement;
-import com.google.gson.JsonObject;
-
-/**
- * This is the Utility for pre-processing the data for Bert Model
- * You can use this utility to parse Vocabulary JSON into Java Array and Dictionary,
- * clean and tokenize sentences and pad the text
- */
-public class BertDataParser {
-
-    private Map<String, Integer> token2idx;
-    private List<String> idx2token;
-
-    /**
-     * Parse the Vocabulary to JSON files
-     * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens
-     * @param jsonFile the filePath of the vocab.json
-     * @throws Exception
-     */
-    public void parseJSON(String jsonFile) throws Exception {
-        Gson gson = new Gson();
-        token2idx = new HashMap<>();
-        idx2token = new LinkedList<>();
-        JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class);
-        JsonArray arr = jsonObject.getAsJsonArray("idx_to_token");
-        for (JsonElement element : arr) {
-            idx2token.add(element.getAsString());
-        }
-        JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx");
-        for (String key : preMap.keySet()) {
-            token2idx.put(key, preMap.get(key).getAsInt());
-        }
-    }
-
-    /**
-     * Tokenize the input, split all kinds of whitespace and
-     * Separate the end of sentence symbol: . , ? !
-     * @param input The input string
-     * @return List of tokens
-     */
-    public List<String> tokenizer(String input) {
-        String[] step1 = input.split("\\s+");
-        List<String> finalResult = new LinkedList<>();
-        for (String item : step1) {
-            if (item.length() != 0) {
-                if ((item + "a").split("[.,?!]+").length > 1) {
-                    finalResult.add(item.substring(0, item.length() - 1));
-                    finalResult.add(item.substring(item.length() -1));
-                } else {
-                    finalResult.add(item);
-                }
-            }
-        }
-        return finalResult;
-    }
-
-    /**
-     * Pad the tokens to the required length
-     * @param tokens input tokens
-     * @param padItem things to pad at the end
-     * @param num total length after padding
-     * @return List of padded tokens
-     */
-    public <E> List<E> pad(List<E> tokens, E padItem, int num) {
-        if (tokens.size() >= num) return tokens;
-        List<E> padded = new LinkedList<>(tokens);
-        for (int i = 0; i < num - tokens.size(); i++) {
-            padded.add(padItem);
-        }
-        return padded;
-    }
-
-    /**
-     * Convert tokens to indexes
-     * @param tokens input tokens
-     * @return List of indexes
-     */
-    public List<Integer> token2idx(List<String> tokens) {
-        List<Integer> indexes = new ArrayList<>();
-        for (String token : tokens) {
-            if (token2idx.containsKey(token)) {
-                indexes.add(token2idx.get(token));
-            } else {
-                indexes.add(token2idx.get("[UNK]"));
-            }
-        }
-        return indexes;
-    }
-
-    /**
-     * Convert indexes to tokens
-     * @param indexes List of indexes
-     * @return List of tokens
-     */
-    public List<String> idx2token(List<Integer> indexes) {
-        List<String> tokens = new ArrayList<>();
-        for (int index : indexes) {
-            tokens.add(idx2token.get(index));
-        }
-        return tokens;
-    }
-}
diff --git a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
deleted file mode 100644
index 8521f0b..0000000
--- a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package bert;
-
-import org.apache.mxnet.infer.javaapi.Predictor;
-import org.apache.mxnet.javaapi.*;
-import org.kohsuke.args4j.CmdLineParser;
-import org.kohsuke.args4j.Option;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.*;
-
-/**
- * This is an example of using BERT to do the general Question and Answer inference jobs
- * Users can provide a question with a paragraph contains answer to the model and
- * the model will be able to find the best answer from the answer paragraph
- */
-public class BertQA {
-    @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
-    private String modelPathPrefix = "model/static_bert_qa";
-    @Option(name = "--model-epoch", usage = "Epoch number of the model")
-    private int epoch = 2;
-    @Option(name = "--model-vocab", usage = "the vocabulary used in the model")
-    private String modelVocab = "model/vocab.json";
-    @Option(name = "--input-question", usage = "the input question")
-    private String inputQ = "When did BBC Japan start broadcasting?";
-    @Option(name = "--input-answer", usage = "the input answer")
-    private String inputA =
-        "BBC Japan was a general entertainment Channel.\n" +
-                " Which operated between December 2004 and April 2006.\n" +
-            "It ceased operations after its Japanese distributor folded.";
-    @Option(name = "--seq-length", usage = "the maximum length of the sequence")
-    private int seqLength = 384;
-
-    private final static Logger logger = LoggerFactory.getLogger(BertQA.class);
-    private static NDArray$ NDArray = NDArray$.MODULE$;
-
-    private static int argmax(float[] prob) {
-        int maxIdx = 0;
-        for (int i = 0; i < prob.length; i++) {
-            if (prob[maxIdx] < prob[i]) maxIdx = i;
-        }
-        return maxIdx;
-    }
-
-    /**
-     * Do the post processing on the output, apply softmax to get the probabilities
-     * reshape and get the most probable index
-     * @param result prediction result
-     * @param tokens word tokens
-     * @return Answers clipped from the original paragraph
-     */
-    static List<String> postProcessing(NDArray result, List<String> tokens) {
-        NDArray[] output = NDArray.split(
-                NDArray.new splitParam(result, 2).setAxis(2));
-       logger.info("Carin postprocessing output: " + Arrays.toString(output));
-        // Get the formatted logits result
-        NDArray startLogits = output[0].reshape(new int[]{0, -3});
-        NDArray endLogits = output[1].reshape(new int[]{0, -3});
-        // Get Probability distribution
-        float[] startProb = NDArray.softmax(
-                NDArray.new softmaxParam(startLogits))[0].toArray();
-        float[] endProb = NDArray.softmax(
-                NDArray.new softmaxParam(endLogits))[0].toArray();
-        int startIdx = argmax(startProb);
-        int endIdx = argmax(endProb);
-       logger.info("Carin startIdx "+ startIdx);
-       logger.info("Carin endIdx "+ startIdx);
-        return tokens.subList(startIdx, endIdx + 1);
-    }
-
-    public static void main(String[] args) throws Exception{
-        BertQA inst = new BertQA();
-        CmdLineParser parser = new CmdLineParser(inst);
-        parser.parseArgument(args);
-        BertDataParser util = new BertDataParser();
-        Context context = Context.cpu();
-        if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
-                Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
-            context = Context.gpu();
-        }
-        // pre-processing - tokenize sentence
-        List<String> tokenQ = util.tokenizer(inst.inputQ.toLowerCase());
-        List<String> tokenA = util.tokenizer(inst.inputA.toLowerCase());
-        int validLength = tokenQ.size() + tokenA.size();
-        logger.info("Valid length: " + validLength);
-        // generate token types [0000...1111....0000]
-        List<Float> QAEmbedded = new ArrayList<>();
-        util.pad(QAEmbedded, 0f, tokenQ.size()).addAll(
-                util.pad(new ArrayList<Float>(), 1f, tokenA.size())
-        );
-        List<Float> tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength);
-        // make BERT pre-processing standard
-        tokenQ.add("[SEP]");
-        tokenQ.add(0, "[CLS]");
-        tokenA.add("[SEP]");
-        tokenQ.addAll(tokenA);
-        List<String> tokens = util.pad(tokenQ, "[PAD]", inst.seqLength);
-        logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray()));
-        // pre-processing - token to index translation
-        util.parseJSON(inst.modelVocab);
-        List<Integer> indexes = util.token2idx(tokens);
-        List<Float> indexesFloat = new ArrayList<>();
-        for (int integer : indexes) {
-            indexesFloat.add((float) integer);
-        }
-        // Preparing the input data
-        List<NDArray> inputBatch = Arrays.asList(
-                new NDArray(indexesFloat,
-                        new Shape(new int[]{1, inst.seqLength}), context),
-                new NDArray(tokenTypes,
-                        new Shape(new int[]{1, inst.seqLength}), context),
-                new NDArray(new float[] { validLength },
-                        new Shape(new int[]{1}), context)
-        );
-       logger.info("Carin inputbatch: " + Arrays.toString(inputBatch.toArray()));
-        // Build the model
-        List<Context> contexts = new ArrayList<>();
-        contexts.add(context);
-        List<DataDesc> inputDescs = Arrays.asList(
-                new DataDesc("data0",
-                        new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
-                new DataDesc("data1",
-                        new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
-                new DataDesc("data2",
-                        new Shape(new int[]{1}), DType.Float32(), Layout.N())
-        );
-        Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch);
-        // Start prediction
-        NDArray result = bertQA.predictWithNDArray(inputBatch).get(0);
-        List<String> answer = postProcessing(result, tokens);
-        logger.info("Question: " + inst.inputQ);
-        logger.info("Answer paragraph: " + inst.inputA);
-        logger.info("Answer: " + Arrays.toString(answer.toArray()));
-    }
-}

Reply via email to