This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-bert-qa-example in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 37e7b8483f07e0fb0f911cb5d149e6dadea007ea Author: gigasquid <[email protected]> AuthorDate: Fri Apr 12 19:18:20 2019 -0400 qa example working --- .../examples/bert-qa/src/bert_qa/core.clj | 42 ++++++++++++---------- .../examples/bert-qa/src/java/BertQA.java | 3 ++ 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj index 1876a83..1c7b14c 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj @@ -59,6 +59,22 @@ (mapv #(get idx2token %) idxs)) +(defn post-processing [result tokens] + (let [output1 (ndarray/slice-axis result 2 0 1) + output2 (ndarray/slice-axis result 2 1 2) + ;;; get the formatted logits result + start-logits (ndarray/reshape output1 [0 -3]) + end-logits (ndarray/reshape output2 [0 -3]) + start-prob (ndarray/softmax start-logits) + end-prob (ndarray/softmax end-logits) + start-idx (-> (ndarray/argmax start-prob 1) + (ndarray/->vec) + (first)) + end-idx (-> (ndarray/argmax end-prob 1) + (ndarray/->vec) + (first))] + (subvec tokens (dec start-idx) (inc end-idx)))) + (defn infer [] (let [ctx (context/default-context) ;;; pre-processing tokenize sentence @@ -99,25 +115,15 @@ factory {:contexts [ctx] :epoch 2}) - result (first (infer/predict-with-ndarray predictor input-batch))] - result) - - ) + ;;; start predication + result (first (infer/predict-with-ndarray predictor input-batch)) + answer (post-processing result tokens)] + (println "Question: " input-q) + (println "Answer paragraph: " input-a) + (println "Answer: " answer))) (comment - (repeat 3 0) - (def x ) - (keys x) - (get x "idx_to_token") - (def bert-parser (new BertDataParser)) - (.parseJSON bert-parser model-vocab) - (.token2idx bert-parser (java.util.ArrayList. ["and" "where"])) - [1998 2073] - (.idx2token bert-parser (java.util.ArrayList. (map int [1998 2073]))) - ["and" "where"] - (def bert-qa (new BertQA)) - - (BertQA/main (into-array ["--model-path-prefix" "model/static_bert_qa"]) ) - (r/reflect bert-qa)) + (infer) +) diff --git a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java index 7308bca..8521f0b 100644 --- a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java +++ b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java @@ -69,6 +69,7 @@ public class BertQA { static List<String> postProcessing(NDArray result, List<String> tokens) { NDArray[] output = NDArray.split( NDArray.new splitParam(result, 2).setAxis(2)); + logger.info("Carin postprocessing output: " + Arrays.toString(output)); // Get the formatted logits result NDArray startLogits = output[0].reshape(new int[]{0, -3}); NDArray endLogits = output[1].reshape(new int[]{0, -3}); @@ -79,6 +80,8 @@ public class BertQA { NDArray.new softmaxParam(endLogits))[0].toArray(); int startIdx = argmax(startProb); int endIdx = argmax(endProb); + logger.info("Carin startIdx "+ startIdx); + logger.info("Carin endIdx "+ startIdx); return tokens.subList(startIdx, endIdx + 1); }
