This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-qa-example
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 37e7b8483f07e0fb0f911cb5d149e6dadea007ea
Author: gigasquid <[email protected]>
AuthorDate: Fri Apr 12 19:18:20 2019 -0400

    qa example working
---
 .../examples/bert-qa/src/bert_qa/core.clj          | 42 ++++++++++++----------
 .../examples/bert-qa/src/java/BertQA.java          |  3 ++
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
index 1876a83..1c7b14c 100644
--- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
@@ -59,6 +59,22 @@
   (mapv #(get idx2token %) idxs))


+(defn post-processing "Returns the slice of tokens from the most probable start position through the most probable end position (first entry of result only)." [result tokens]
+  (let [output1 (ndarray/slice-axis result 2 0 1)
+        output2 (ndarray/slice-axis result 2 1 2)
+        ;;; get the formatted logits result
+        start-logits (ndarray/reshape output1 [0 -3])
+        end-logits (ndarray/reshape output2 [0 -3])
+        start-prob (ndarray/softmax start-logits)
+        end-prob (ndarray/softmax end-logits)
+        start-idx (-> (ndarray/argmax start-prob 1)
+                      (ndarray/->vec)
+                      (first))
+        end-idx (-> (ndarray/argmax end-prob 1)
+                    (ndarray/->vec)
+                    (first))]
+    (subvec tokens (min start-idx end-idx) (inc end-idx)))) ; min: start never passes end, so subvec cannot throw
+
 (defn infer []
   (let [ctx (context/default-context)
         ;;; pre-processing tokenize sentence
@@ -99,25 +115,15 @@
                    factory
                    {:contexts [ctx]
                     :epoch 2})
-        result (first (infer/predict-with-ndarray predictor input-batch))]
-    result)
-  
-  )
+        ;;; start prediction
+        result (first (infer/predict-with-ndarray predictor input-batch))
+        answer (post-processing result tokens)]
+    (println "Question:" input-q)
+    (println "Answer paragraph:" input-a)
+    (println "Answer:" answer)))


 (comment

-  (repeat 3 0)
-  (def x )
-  (keys x)
-  (get x "idx_to_token")
-  (def bert-parser (new BertDataParser))
-  (.parseJSON bert-parser model-vocab)
-  (.token2idx bert-parser (java.util.ArrayList. ["and" "where"]))
-  [1998 2073]
-  (.idx2token bert-parser (java.util.ArrayList. (map int [1998 2073])))
-  ["and" "where"]
-  (def bert-qa (new BertQA))
-
-  (BertQA/main (into-array ["--model-path-prefix" "model/static_bert_qa"]) )
-  (r/reflect bert-qa))
+  (infer)
+)
diff --git a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
index 7308bca..8521f0b 100644
--- a/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
+++ b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
@@ -69,6 +69,7 @@ public class BertQA {
     static List<String> postProcessing(NDArray result, List<String> tokens) {
         NDArray[] output = NDArray.split(
                 NDArray.new splitParam(result, 2).setAxis(2));
+        logger.info("Carin postprocessing output: " + Arrays.toString(output));
         // Get the formatted logits result
         NDArray startLogits = output[0].reshape(new int[]{0, -3});
         NDArray endLogits = output[1].reshape(new int[]{0, -3});
@@ -79,6 +80,8 @@ public class BertQA {
                 NDArray.new softmaxParam(endLogits))[0].toArray();
         int startIdx = argmax(startProb);
         int endIdx = argmax(endProb);
+        logger.info("Carin startIdx " + startIdx);
+        logger.info("Carin endIdx " + endIdx);
         return tokens.subList(startIdx, endIdx + 1);
     }
 

Reply via email to